• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
5  * except in compliance with the License. You may obtain a copy of the License at
6  *
7  *      http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software distributed under the
10  * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
11  * KIND, either express or implied. See the License for the specific language governing
12  * permissions and limitations under the License.
13  */
14 package com.android.systemui.plugins.processor
15 
16 import com.android.systemui.plugins.annotations.GeneratedImport
17 import com.android.systemui.plugins.annotations.ProtectedInterface
18 import com.android.systemui.plugins.annotations.ProtectedReturn
19 import com.android.systemui.plugins.annotations.SimpleProperty
20 import com.google.auto.service.AutoService
21 import javax.annotation.processing.AbstractProcessor
22 import javax.annotation.processing.ProcessingEnvironment
23 import javax.annotation.processing.RoundEnvironment
24 import javax.lang.model.element.Element
25 import javax.lang.model.element.ElementKind
26 import javax.lang.model.element.ExecutableElement
27 import javax.lang.model.element.Modifier
28 import javax.lang.model.element.PackageElement
29 import javax.lang.model.element.TypeElement
30 import javax.lang.model.type.TypeKind
31 import javax.lang.model.type.TypeMirror
32 import javax.tools.Diagnostic.Kind
33 import kotlin.collections.ArrayDeque
34 
/**
 * [ProtectedPluginProcessor] generates a proxy implementation for interfaces annotated with
 * [ProtectedInterface] which catches [Exception]s generated by the proxied target. Production
 * plugin interfaces should use this to catch [LinkageError]s as that protects the plugin host from
 * crashing due to out-of-date plugin code, where some call has changed so that the [ClassLoader] is
 * no longer able to resolve it correctly.
 *
 * [PluginInstance] observes these failures via [ProtectedPluginListener] and unloads the plugin in
 * question to prevent further issues. This persists through further load/unload requests.
 *
 * To centralize access to the proxy types, an additional type [PluginProtector] is also generated.
 * This class provides static methods which wrap an instance of the target interface in the proxy
 * type if it is not already an instance of the proxy.
 */
// NOTE(review): AutoService conventionally takes the service interface being implemented
// (javax.annotation.processing.Processor). Confirm that registering this class against itself
// is intended before changing it.
@AutoService(ProtectedPluginProcessor::class)
class ProtectedPluginProcessor : AbstractProcessor() {
    private lateinit var procEnv: ProcessingEnvironment

    /**
     * Stores the processing environment for later use by [process].
     *
     * NOTE(review): `super.init` is not invoked here, so the inherited `processingEnv` field is
     * never populated; every access in this class goes through [procEnv]. Confirm this is intended.
     */
    override fun init(procEnv: ProcessingEnvironment) {
        this.procEnv = procEnv
    }

    /** This processor only reacts to [ProtectedInterface]. */
    override fun getSupportedAnnotationTypes(): Set<String> =
        setOf("com.android.systemui.plugins.annotations.ProtectedInterface")

    /**
     * Everything needed to generate one proxy type.
     *
     * @property attribute the annotation type that selected [sourceType]
     * @property sourceType the annotated element being proxied
     * @property sourcePkg qualified name of the package enclosing [sourceType]
     * @property sourceName simple name of [sourceType]
     * @property outputName simple name of the generated proxy (`<sourceName>Protector`)
     * @property exTypeAttr annotation instance supplying the exception types to catch
     */
    private data class TargetData(
        val attribute: TypeElement,
        val sourceType: Element,
        val sourcePkg: String,
        val sourceName: String,
        val outputName: String,
        val exTypeAttr: ProtectedInterface,
    )

    /**
     * Generates one `<Name>Protector` proxy per annotated interface found in this round, plus the
     * shared `PluginProtector` factory type.
     *
     * @return `false` when the round contains no usable targets, `true` otherwise
     */
    override fun process(annotations: Set<TypeElement>, roundEnv: RoundEnvironment): Boolean {
        val targets = mutableMapOf<String, TargetData>() // keyed by fully-qualified source name
        val additionalImports = mutableSetOf<String>()
        for (attr in annotations) {
            for (target in roundEnv.getElementsAnnotatedWith(attr)) {
                // The generated imports and file names assume a top-level type. Report anything
                // else as a compile error instead of failing with a ClassCastException.
                val enclosingPkg = target.enclosingElement as? PackageElement
                if (enclosingPkg == null) {
                    procEnv.messager.printMessage(
                        Kind.ERROR,
                        "@ProtectedInterface is only supported on top-level interfaces: $target",
                        target,
                    )
                    continue
                }

                // Find the target exception types to be used, falling back to the defaults
                // when the annotation does not list any.
                val exTypeAttr: ProtectedInterface =
                    target.getAnnotation(ProtectedInterface::class.java)?.takeIf {
                        it.exTypes.isNotEmpty()
                    } ?: ProtectedInterface.Default

                val sourceName = "${target.simpleName}"
                val outputName = "${sourceName}Protector"
                val pkg = enclosingPkg.qualifiedName.toString()
                targets[target.toString()] =
                    TargetData(attr, target, pkg, sourceName, outputName, exTypeAttr)

                // This creates excessive imports, but it should be fine
                additionalImports.add("$pkg.$sourceName")
                additionalImports.add("$pkg.$outputName")
            }
        }

        if (targets.isEmpty()) return false

        for (target in targets.values) {
            writeProxy(target, targets, additionalImports)
        }
        writeFactory(targets.values, additionalImports)
        return true
    }

    /** Joins [pkg] and [simpleName] into the canonical name expected by the Filer. */
    private fun qualifiedName(pkg: String, simpleName: String): String =
        if (pkg.isEmpty()) simpleName else "$pkg.$simpleName"

    /**
     * Collects every method declared on [sourceType] and on all of its super types, together
     * with the [GeneratedImport] annotations found along the way.
     *
     * Each type declaration is scanned at most once, so a super-interface that is reachable via
     * several paths does not contribute its methods (or imports) more than once.
     */
    private fun collectMembers(
        sourceType: Element
    ): Pair<List<ExecutableElement>, List<GeneratedImport>> {
        val pending = ArrayDeque<TypeMirror>().apply { addLast(sourceType.asType()) }
        val visited = mutableSetOf<Element>()
        val methods = mutableListOf<ExecutableElement>()
        val impAttrs = mutableListOf<GeneratedImport>()
        while (pending.isNotEmpty()) {
            val typeMirror = pending.removeLast()
            if (typeMirror.toString() == "java.lang.Object") continue
            val type = procEnv.typeUtils.asElement(typeMirror) ?: continue
            if (!visited.add(type)) continue

            type.enclosedElements
                .filter { it.kind == ElementKind.METHOD }
                .mapTo(methods) { it as ExecutableElement }

            impAttrs.addAll(type.getAnnotationsByType(GeneratedImport::class.java))
            pending.addAll(procEnv.typeUtils.directSupertypes(typeMirror))
        }
        return methods to impAttrs
    }

    /**
     * Writes the proxy source file for [target].
     *
     * @param targets all targets of this round keyed by qualified name; used to wrap return
     *   values whose type is itself a protected interface
     * @param additionalImports qualified names imported by every generated file
     */
    private fun writeProxy(
        target: TargetData,
        targets: Map<String, TargetData>,
        additionalImports: Set<String>,
    ) {
        val sourcePkg = target.sourcePkg
        val sourceName = target.sourceName
        val outputName = target.outputName
        val exTypes = target.exTypeAttr.exTypes
        val (methods, impAttrs) = collectMembers(target.sourceType)

        // The Filer expects the canonical (fully qualified) name of the declared type.
        val file =
            procEnv.filer.createSourceFile(qualifiedName(sourcePkg, outputName), target.sourceType)
        TabbedWriter.writeTo(file.openWriter()) {
            line("package $sourcePkg;")
            line()

            // Imports used by the proxy implementation
            line("import android.util.Log;")
            line("import com.android.systemui.plugins.PluginWrapper;")
            line("import com.android.systemui.plugins.ProtectedPluginListener;")
            line()

            // Imports of other generated types
            if (additionalImports.isNotEmpty()) {
                for (impTarget in additionalImports) {
                    line("import $impTarget;")
                }
                line()
            }

            // Imports of caught exceptions
            if (exTypes.isNotEmpty()) {
                for (exType in exTypes) {
                    line("import $exType;")
                }
                line()
            }

            // Imports declared via @GeneratedImport
            if (impAttrs.isNotEmpty()) {
                for (impAttr in impAttrs) {
                    line("import ${impAttr.extraImport};")
                }
                line()
            }

            val interfaces = "$sourceName, PluginWrapper<$sourceName>"
            braceBlock("public class $outputName implements $interfaces") {
                line("private static final String CLASS = \"$sourceName\";")

                // Static factory method to prevent wrapping the same object twice
                parenBlock("public static $outputName protect") {
                    line("$sourceName instance,")
                    line("ProtectedPluginListener listener")
                }
                braceBlock {
                    line("if (instance instanceof $outputName)")
                    line("    return ($outputName)instance;")
                    line("return new $outputName(instance, listener);")
                }
                line()

                // Member Fields
                line("private $sourceName mInstance;")
                line("private ProtectedPluginListener mListener;")
                line("private boolean mHasError = false;")
                line()

                // Constructor
                parenBlock("private $outputName") {
                    line("$sourceName instance,")
                    line("ProtectedPluginListener listener")
                }
                braceBlock {
                    line("mInstance = instance;")
                    line("mListener = listener;")
                }
                line()

                // Wrapped instance getter for version checker
                braceBlock("public $sourceName getPlugin()") { line("return mInstance;") }

                // Method implementations
                for (method in methods) {
                    val methodName = method.simpleName
                    // Skip any method whose name is another collected method's name followed
                    // by '$' (presumably compiler-generated companions -- TODO confirm).
                    if (methods.any { methodName.startsWith("${it.simpleName}\$") }) {
                        continue
                    }

                    val returnTypeName = method.returnType.toString()
                    val isStatic = Modifier.STATIC in method.modifiers
                    val isVoid = method.returnType.kind == TypeKind.VOID

                    if (!isStatic) {
                        line("@Override")
                    }
                    // Copy the method signature onto the proxy type.
                    parenBlock("public $returnTypeName $methodName") {
                        for ((index, param) in method.parameters.withIndex()) {
                            if (index > 0) completeLine(",")
                            startLine("${param.asType()} ${param.simpleName}")
                        }
                    }

                    // Static methods are forwarded to the interface, others to the instance.
                    val callArgs = method.parameters.joinToString(", ") { it.simpleName }
                    val methodContainer = if (isStatic) sourceName else "mInstance"
                    val nestedCall = "$methodContainer.$methodName($callArgs)"
                    val protectedReturn = targets[returnTypeName]
                    val callStatement =
                        when {
                            isVoid -> "$nestedCall;"
                            // Returned protected interfaces are wrapped in their own proxy.
                            protectedReturn != null ->
                                "return ${protectedReturn.outputName}" +
                                    ".protect($nestedCall, mListener);"
                            else -> "return $nestedCall;"
                        }

                    // Simple property methods forgo protection
                    if (method.getAnnotation(SimpleProperty::class.java) != null) {
                        braceBlock {
                            line("final String METHOD = \"$methodName\";")
                            line(callStatement)
                        }
                        line()
                        continue
                    }

                    // Standard implementation wraps nested call in try-catch
                    braceBlock {
                        val retAttr = method.getAnnotation(ProtectedReturn::class.java)
                        val errorStatement =
                            when {
                                retAttr != null -> retAttr.statement
                                isVoid -> "return;"
                                else -> {
                                    // Non-void methods must be annotated. The method element is
                                    // attached so the error can point at its declaration.
                                    procEnv.messager.printMessage(
                                        Kind.ERROR,
                                        "$outputName.$methodName must be annotated with " +
                                            "@ProtectedReturn or @SimpleProperty",
                                        method,
                                    )
                                    "throw ex;"
                                }
                            }

                        line("final String METHOD = \"$methodName\";")

                        // Return immediately if any previous call has failed.
                        braceBlock("if (mHasError)") { line(errorStatement) }

                        // Protect callsite in try/catch block
                        braceBlock("try") { line(callStatement) }

                        // Notify listener when a target exception is caught
                        for (exType in exTypes) {
                            val simpleName = exType.substringAfterLast(".")
                            braceBlock("catch ($simpleName ex)") {
                                line("Log.wtf(CLASS, \"Failed to execute: \" + METHOD, ex);")
                                line("mHasError = mListener.onFail(CLASS, METHOD, ex);")
                                line(errorStatement)
                            }
                        }
                    }
                    line()
                }
            }
        }
    }

    /**
     * Writes the centralized static factory type to its own file. This is for convenience so
     * that PluginInstance need not resolve each generated type at runtime as plugins are loaded.
     *
     * @param targets every proxy target generated in this round
     * @param additionalImports qualified names of the source and proxy types to import
     */
    private fun writeFactory(targets: Collection<TargetData>, additionalImports: Set<String>) {
        // Qualified to match the package declaration written below.
        val originating = targets.map { it.sourceType }.toTypedArray()
        val factoryFile =
            procEnv.filer.createSourceFile(
                "com.android.systemui.plugins.PluginProtector",
                *originating,
            )
        TabbedWriter.writeTo(factoryFile.openWriter()) {
            line("package com.android.systemui.plugins;")
            line()

            line("import java.util.Map;")
            line("import java.util.ArrayList;")
            line("import java.util.HashSet;")
            line("import android.util.Log;")
            line("import static java.util.Map.entry;")
            line()

            for (impTarget in additionalImports) {
                line("import $impTarget;")
            }
            line()

            braceBlock("public final class PluginProtector") {
                line("private PluginProtector() { }")
                line()

                line("private static final String TAG = \"PluginProtector\";")
                line()

                // Untyped factory SAM, private to this type.
                braceBlock("private interface Factory") {
                    line("Object create(Object plugin, ProtectedPluginListener listener);")
                }
                line()

                // Store a reference to each `protect` method in a map by interface type.
                parenBlock("private static final Map<Class, Factory> sFactories = Map.ofEntries") {
                    for ((index, target) in targets.withIndex()) {
                        if (index > 0) completeLine(",")
                        startLine("entry(${target.sourceName}.class, ")
                        appendLine(
                            "(p, h) -> ${target.outputName}.protect((${target.sourceName})p, h))"
                        )
                    }
                }
                completeLine(";")
                line()

                // Lookup the relevant factory based on the instance type, if not found return null.
                parenBlock("public static <T> T tryProtect") {
                    line("T target,")
                    line("ProtectedPluginListener listener")
                }
                braceBlock {
                    // Accumulate interfaces from type and all base types
                    line("HashSet<Class> interfaces = new HashSet<Class>();")
                    line("Class current = target.getClass();")
                    braceBlock("while (current != null)") {
                        braceBlock("for (Class cls : current.getInterfaces())") {
                            line("interfaces.add(cls);")
                        }
                        line("current = current.getSuperclass();")
                    }
                    line()

                    // Check if any of the interfaces are marked protectable
                    line("int candidateCount = 0;")
                    line("Factory candidateFactory = null;")
                    braceBlock("for (Class cls : interfaces)") {
                        line("Factory factory = sFactories.get(cls);")
                        braceBlock("if (factory != null)") {
                            line("candidateFactory = factory;")
                            line("candidateCount++;")
                        }
                    }
                    line()

                    // No match, return null
                    braceBlock("if (candidateFactory == null)") {
                        line("Log.i(TAG, \"Wasn't able to wrap \" + target);")
                        line("return null;")
                    }

                    // Multiple matches, not supported
                    braceBlock("if (candidateCount >= 2)") {
                        val error = "Plugin implements more than one protected interface"
                        line("throw new UnsupportedOperationException(\"$error\");")
                    }

                    // Call the factory and wrap the target object
                    line("return (T)candidateFactory.create(target, listener);")
                }
                line()

                // Wraps the target with the appropriate generated proxy if it exists.
                parenBlock("public static <T> T protectIfAble") {
                    line("T target,")
                    line("ProtectedPluginListener listener")
                }
                braceBlock {
                    line("T result = tryProtect(target, listener);")
                    line("return result != null ? result : target;")
                }
                line()
            }
        }
    }
}
386