/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
 * except in compliance with the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the
 * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package com.android.systemui.plugins.processor

import com.android.systemui.plugins.annotations.GeneratedImport
import com.android.systemui.plugins.annotations.ProtectedInterface
import com.android.systemui.plugins.annotations.ProtectedReturn
import com.android.systemui.plugins.annotations.SimpleProperty
import com.google.auto.service.AutoService
import javax.annotation.processing.AbstractProcessor
import javax.annotation.processing.ProcessingEnvironment
import javax.annotation.processing.RoundEnvironment
import javax.lang.model.SourceVersion
import javax.lang.model.element.Element
import javax.lang.model.element.ElementKind
import javax.lang.model.element.ExecutableElement
import javax.lang.model.element.Modifier
import javax.lang.model.element.PackageElement
import javax.lang.model.element.TypeElement
import javax.lang.model.type.TypeKind
import javax.lang.model.type.TypeMirror
import javax.tools.Diagnostic.Kind
import kotlin.collections.ArrayDeque

/**
 * [ProtectedPluginProcessor] generates a proxy implementation for interfaces annotated with
 * [ProtectedInterface]. The proxy catches the throwable types listed in the annotation's `exTypes`
 * (or those of [ProtectedInterface.Default] when none are given) when the proxied target throws
 * them. Production plugin interfaces should use this to catch [LinkageError]s. That protects the
 * plugin host from crashing due to out-of-date plugin code, where some call has changed so that
 * the [ClassLoader] is no longer able to resolve it correctly.
 *
 * [PluginInstance] observes these failures via [ProtectedPluginListener] and unloads the plugin in
 * question to prevent further issues. This persists through further load/unload requests.
 *
 * To centralize access to the proxy types, an additional type [PluginProtector] is also generated.
 * This class provides static methods which wrap an instance of the target interface in the proxy
 * type if it is not already an instance of the proxy.
 */
// NOTE(review): AutoService's value is normally the service interface (Processor::class). As written,
// the registration file is keyed by this class's own name. Confirm the build registers the
// processor explicitly (e.g. via processor_class) before relying on ServiceLoader discovery.
@AutoService(ProtectedPluginProcessor::class)
class ProtectedPluginProcessor : AbstractProcessor() {
    private lateinit var procEnv: ProcessingEnvironment

    override fun init(procEnv: ProcessingEnvironment) {
        // Let AbstractProcessor record the environment and mark itself initialized. Without this,
        // its processingEnv field stays null and isInitialized() keeps returning false.
        super.init(procEnv)
        this.procEnv = procEnv
    }

    override fun getSupportedAnnotationTypes(): Set<String> =
        setOf("com.android.systemui.plugins.annotations.ProtectedInterface")

    /**
     * Claims support for the latest source version. Without this override, [AbstractProcessor]
     * reports RELEASE_6 (no @SupportedSourceVersion is present), and javac then warns that the
     * processor's supported version is lower than `-source`.
     */
    override fun getSupportedSourceVersion(): SourceVersion = SourceVersion.latestSupported()

    /** Everything needed to generate the proxy for a single [ProtectedInterface] target. */
    private data class TargetData(
        val attribute: TypeElement,
        val sourceType: Element,
        val sourcePkg: String,
        val sourceName: String,
        val outputName: String,
        val exTypeAttr: ProtectedInterface,
    )

    /** Methods and [GeneratedImport]s gathered from a target type and all of its supertypes. */
    private data class TypeMembers(
        val methods: List<ExecutableElement>,
        val imports: List<GeneratedImport>,
    )

    /**
     * Collects every element annotated with one of [annotations] in this round. It writes one
     * `<Name>Protector` proxy per target plus the shared `PluginProtector` factory.
     *
     * @return false when no targets were found in this round, true otherwise.
     */
    override fun process(annotations: Set<TypeElement>, roundEnv: RoundEnvironment): Boolean {
        val targets = mutableMapOf<String, TargetData>() // keyed by fully-qualified source name
        val additionalImports = mutableSetOf<String>()
        for (attr in annotations) {
            for (target in roundEnv.getElementsAnnotatedWith(attr)) {
                // The proxy is emitted into the target's package and named after its simple name,
                // which only works for top-level types. Report anything else as a diagnostic
                // instead of crashing the processor on a failed cast.
                val pkgElement = target.enclosingElement as? PackageElement
                if (pkgElement == null) {
                    procEnv.messager.printMessage(
                        Kind.ERROR,
                        "@ProtectedInterface is only supported on top-level types",
                        target,
                    )
                    continue
                }

                // Find the target exception types to be used, falling back to the defaults.
                val declaredAttr = target.getAnnotation(ProtectedInterface::class.java)
                val exTypeAttr: ProtectedInterface =
                    if (declaredAttr == null || declaredAttr.exTypes.isEmpty()) {
                        ProtectedInterface.Default
                    } else {
                        declaredAttr
                    }

                val sourceName = "${target.simpleName}"
                val outputName = "${sourceName}Protector"
                val pkg = pkgElement.qualifiedName.toString()
                targets["$target"] =
                    TargetData(attr, target, pkg, sourceName, outputName, exTypeAttr)

                // This creates excessive imports, but it should be fine
                additionalImports.add("$pkg.$sourceName")
                additionalImports.add("$pkg.$outputName")
            }
        }

        if (targets.isEmpty()) return false
        for (target in targets.values) {
            writeProxy(target, targets, additionalImports)
        }
        writeProtectorFactory(targets, additionalImports)
        return true
    }

    /**
     * Walks [sourceType] and all of its supertypes, excluding `java.lang.Object`. Collects the
     * methods the proxy must implement and any [GeneratedImport]s declared on those types.
     *
     * Each type is visited at most once. An interface reachable through several inheritance paths
     * would otherwise contribute duplicate method implementations to the proxy, which javac
     * rejects.
     */
    private fun collectMembers(sourceType: Element): TypeMembers {
        val pending = ArrayDeque<TypeMirror>().apply { addLast(sourceType.asType()) }
        val visited = mutableSetOf<String>()
        val imports = mutableListOf<GeneratedImport>()
        val methods = mutableListOf<ExecutableElement>()
        while (pending.isNotEmpty()) {
            val typeMirror = pending.removeLast()
            val typeName = typeMirror.toString()
            if (typeName == "java.lang.Object" || !visited.add(typeName)) continue

            val type = procEnv.typeUtils.asElement(typeMirror)
            type.enclosedElements
                .filter { it.kind == ElementKind.METHOD }
                .mapTo(methods) { it as ExecutableElement }
            imports.addAll(type.getAnnotationsByType(GeneratedImport::class.java))
            pending.addAll(procEnv.typeUtils.directSupertypes(typeMirror))
        }
        return TypeMembers(methods, imports)
    }

    /**
     * Generates `<sourcePkg>.<outputName>`, the proxy that wraps an instance of [target]'s
     * interface.
     *
     * @param targets all targets of this round, keyed by fully-qualified name. Methods returning
     *   one of these types have their result wrapped by that type's proxy as well.
     * @param additionalImports imports of every source and generated type in this round.
     */
    private fun writeProxy(
        target: TargetData,
        targets: Map<String, TargetData>,
        additionalImports: Set<String>,
    ) {
        val (_, sourceType, sourcePkg, sourceName, outputName, exTypeAttr) = target
        val (methods, impAttrs) = collectMembers(sourceType)

        // Filer expects the canonical name of the principal type. The originating element lets
        // tools associate the generated file with its source interface.
        val file = procEnv.filer.createSourceFile("$sourcePkg.$outputName", sourceType)
        TabbedWriter.writeTo(file.openWriter()) {
            line("package $sourcePkg;")
            line()

            // Imports used by the proxy implementation
            line("import android.util.Log;")
            line("import com.android.systemui.plugins.PluginWrapper;")
            line("import com.android.systemui.plugins.ProtectedPluginListener;")
            line()

            // Imports of other generated types
            if (additionalImports.isNotEmpty()) {
                for (impTarget in additionalImports) {
                    line("import $impTarget;")
                }
                line()
            }

            // Imports of caught exceptions
            if (exTypeAttr.exTypes.isNotEmpty()) {
                for (exType in exTypeAttr.exTypes) {
                    line("import $exType;")
                }
                line()
            }

            // Imports declared via @GeneratedImport
            if (impAttrs.isNotEmpty()) {
                for (impAttr in impAttrs) {
                    line("import ${impAttr.extraImport};")
                }
                line()
            }

            val interfaces = "$sourceName, PluginWrapper<$sourceName>"
            braceBlock("public class $outputName implements $interfaces") {
                line("private static final String CLASS = \"$sourceName\";")

                // Static factory method to prevent wrapping the same object twice
                parenBlock("public static $outputName protect") {
                    line("$sourceName instance,")
                    line("ProtectedPluginListener listener")
                }
                braceBlock {
                    line("if (instance instanceof $outputName)")
                    line("    return ($outputName)instance;")
                    line("return new $outputName(instance, listener);")
                }
                line()

                // Member Fields
                line("private $sourceName mInstance;")
                line("private ProtectedPluginListener mListener;")
                line("private boolean mHasError = false;")
                line()

                // Constructor
                parenBlock("private $outputName") {
                    line("$sourceName instance,")
                    line("ProtectedPluginListener listener")
                }
                braceBlock {
                    line("mInstance = instance;")
                    line("mListener = listener;")
                }
                line()

                // Wrapped instance getter for version checker
                braceBlock("public $sourceName getPlugin()") { line("return mInstance;") }

                // Method implementations
                for (method in methods) {
                    val methodName = method.simpleName
                    // Skip methods named "<otherMethod>$..."; presumably compiler-generated
                    // helpers (e.g. Kotlin's `foo$default`) — TODO confirm the exact intent.
                    if (methods.any { methodName.startsWith("${it.simpleName}\$") }) {
                        continue
                    }
                    val returnTypeName = method.returnType.toString()
                    val isStatic = method.modifiers.contains(Modifier.STATIC)
                    // Arguments forwarded to the nested call, in declaration order.
                    val callArgs = method.parameters.joinToString(", ") { it.simpleName }

                    if (!isStatic) {
                        line("@Override")
                    }
                    parenBlock("public $returnTypeName $methodName") {
                        // Copy the method signature for the proxy type.
                        method.parameters.forEachIndexed { index, param ->
                            if (index > 0) completeLine(",")
                            startLine("${param.asType()} ${param.simpleName}")
                        }
                    }

                    val isVoid = method.returnType.kind == TypeKind.VOID
                    val methodContainer = if (isStatic) sourceName else "mInstance"
                    val nestedCall = "$methodContainer.$methodName($callArgs)"
                    // Returned instances of other protected interfaces are wrapped as well.
                    val protectedReturn = targets[returnTypeName]
                    val callStatement =
                        when {
                            isVoid -> "$nestedCall;"
                            protectedReturn != null -> {
                                val targetType = protectedReturn.outputName
                                "return $targetType.protect($nestedCall, mListener);"
                            }
                            else -> "return $nestedCall;"
                        }

                    // Simple property methods forgo protection
                    if (method.getAnnotation(SimpleProperty::class.java) != null) {
                        braceBlock {
                            line("final String METHOD = \"$methodName\";")
                            line(callStatement)
                        }
                        line()
                        continue
                    }

                    // Standard implementation wraps nested call in try-catch
                    braceBlock {
                        val retAttr = method.getAnnotation(ProtectedReturn::class.java)
                        val errorStatement =
                            when {
                                retAttr != null -> retAttr.statement
                                isVoid -> "return;"
                                else -> {
                                    // Non-void methods must be annotated; attach the error to the
                                    // offending method so the diagnostic points at it.
                                    procEnv.messager.printMessage(
                                        Kind.ERROR,
                                        "$outputName.$methodName must be annotated with " +
                                            "@ProtectedReturn or @SimpleProperty",
                                        method,
                                    )
                                    "throw ex;"
                                }
                            }

                        line("final String METHOD = \"$methodName\";")

                        // Return immediately if any previous call has failed.
                        braceBlock("if (mHasError)") { line(errorStatement) }

                        // Protect callsite in try/catch block
                        braceBlock("try") { line(callStatement) }

                        // Notify listener when a target exception is caught
                        for (exType in exTypeAttr.exTypes) {
                            val simpleName = exType.substringAfterLast(".")
                            braceBlock("catch ($simpleName ex)") {
                                line("Log.wtf(CLASS, \"Failed to execute: \" + METHOD, ex);")
                                line("mHasError = mListener.onFail(CLASS, METHOD, ex);")
                                line(errorStatement)
                            }
                        }
                    }
                    line()
                }
            }
        }
    }

    /**
     * Writes a centralized static factory type, `com.android.systemui.plugins.PluginProtector`, to
     * its own file. This is for convenience, so that PluginInstance need not resolve each generated
     * type at runtime as plugins are loaded.
     */
    private fun writeProtectorFactory(
        targets: Map<String, TargetData>,
        additionalImports: Set<String>,
    ) {
        val originatingElements = targets.values.map { it.sourceType }.toTypedArray()
        val factoryFile =
            procEnv.filer.createSourceFile(
                "com.android.systemui.plugins.PluginProtector",
                *originatingElements,
            )
        TabbedWriter.writeTo(factoryFile.openWriter()) {
            line("package com.android.systemui.plugins;")
            line()

            line("import java.util.Map;")
            line("import java.util.ArrayList;")
            line("import java.util.HashSet;")
            line("import android.util.Log;")
            line("import static java.util.Map.entry;")
            line()

            for (impTarget in additionalImports) {
                line("import $impTarget;")
            }
            line()

            braceBlock("public final class PluginProtector") {
                line("private PluginProtector() { }")
                line()

                line("private static final String TAG = \"PluginProtector\";")
                line()

                // Untyped factory SAM, private to this type.
                braceBlock("private interface Factory") {
                    line("Object create(Object plugin, ProtectedPluginListener listener);")
                }
                line()

                // Store a reference to each `protect` method in a map by interface type.
                parenBlock("private static final Map<Class, Factory> sFactories = Map.ofEntries") {
                    for ((index, target) in targets.values.withIndex()) {
                        val sourceName = target.sourceName
                        if (index > 0) completeLine(",")
                        startLine("entry($sourceName.class, ")
                        appendLine("(p, h) -> ${target.outputName}.protect(($sourceName)p, h))")
                    }
                }
                completeLine(";")
                line()

                // Lookup the relevant factory based on the instance type, if not found return null.
                parenBlock("public static <T> T tryProtect") {
                    line("T target,")
                    line("ProtectedPluginListener listener")
                }
                braceBlock {
                    // Accumulate interfaces from type and all base types
                    line("HashSet<Class> interfaces = new HashSet<Class>();")
                    line("Class current = target.getClass();")
                    braceBlock("while (current != null)") {
                        braceBlock("for (Class cls : current.getInterfaces())") {
                            line("interfaces.add(cls);")
                        }
                        line("current = current.getSuperclass();")
                    }
                    line()

                    // Check if any of the interfaces are marked protectable
                    line("int candidateCount = 0;")
                    line("Factory candidateFactory = null;")
                    braceBlock("for (Class cls : interfaces)") {
                        line("Factory factory = sFactories.get(cls);")
                        braceBlock("if (factory != null)") {
                            line("candidateFactory = factory;")
                            line("candidateCount++;")
                        }
                    }
                    line()

                    // No match, return null
                    braceBlock("if (candidateFactory == null)") {
                        line("Log.i(TAG, \"Wasn't able to wrap \" + target);")
                        line("return null;")
                    }

                    // Multiple matches, not supported
                    braceBlock("if (candidateCount >= 2)") {
                        val message = "Plugin implements more than one protected interface"
                        line("throw new UnsupportedOperationException(\"$message\");")
                    }

                    // Call the factory and wrap the target object
                    line("return (T)candidateFactory.create(target, listener);")
                }
                line()

                // Wraps the target with the appropriate generated proxy if it exists.
                parenBlock("public static <T> T protectIfAble") {
                    line("T target,")
                    line("ProtectedPluginListener listener")
                }
                braceBlock {
                    line("T result = tryProtect(target, listener);")
                    line("return result != null ? result : target;")
                }
                line()
            }
        }
    }
}