• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 @file:JvmName("Main")
17 
18 package com.android.checkflaggedapis
19 
20 import android.aconfig.Aconfig
21 import com.android.tools.metalava.model.BaseItemVisitor
22 import com.android.tools.metalava.model.ClassItem
23 import com.android.tools.metalava.model.FieldItem
24 import com.android.tools.metalava.model.Item
25 import com.android.tools.metalava.model.MethodItem
26 import com.android.tools.metalava.model.text.ApiFile
27 import com.github.ajalt.clikt.core.CliktCommand
28 import com.github.ajalt.clikt.core.ProgramResult
29 import com.github.ajalt.clikt.core.subcommands
30 import com.github.ajalt.clikt.parameters.options.help
31 import com.github.ajalt.clikt.parameters.options.option
32 import com.github.ajalt.clikt.parameters.options.required
33 import com.github.ajalt.clikt.parameters.types.path
34 import java.io.InputStream
35 import javax.xml.parsers.DocumentBuilderFactory
36 import org.w3c.dom.Node
37 
/**
 * Class representing the fully qualified name of a class, method or field.
 *
 * This tool reads a multitude of input formats, all of which represent the fully qualified path to
 * a Java symbol slightly differently. To keep things consistent, all parsed APIs are converted to
 * Symbols.
 *
 * Symbols are encoded using a format similar to the one described in section 4.3.2 of the JVM spec
 * [1], that is, "package.class.inner-class.method(int, int[], android.util.Clazz)" is represented as
 * ```
 *   package/class/inner-class/method(II[Landroid/util/Clazz;)
 * ```
 *
 * Where possible, the format has been simplified (to make translation of the various input formats
 * easier): for instance, only / is used as delimiter (., # and $ are never used).
 *
 * 1. https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html#jvms-4.3.2
 */
internal sealed class Symbol {
  companion object {
    /** Delimiters that are rewritten to '/' by [toInternalFormat]. */
    private val FORBIDDEN_CHARS = listOf('#', '$', '.')

    /**
     * Creates a [ClassSymbol]. The class, superclass and interface names are all converted to the
     * internal '/'-delimited format.
     */
    fun createClass(clazz: String, superclass: String?, interfaces: Set<String>): Symbol {
      return ClassSymbol(
          toInternalFormat(clazz),
          superclass?.let { toInternalFormat(it) },
          interfaces.map { toInternalFormat(it) }.toSet())
    }

    /**
     * Creates a [MemberSymbol] for a field.
     *
     * @throws IllegalArgumentException if [field] contains a parenthesis (i.e. looks like a method)
     */
    fun createField(clazz: String, field: String): Symbol {
      require(!field.contains("(") && !field.contains(")")) {
        "Field name must not contain parentheses: '$field'"
      }
      return MemberSymbol(toInternalFormat(clazz), toInternalFormat(field))
    }

    /**
     * Creates a [MemberSymbol] for a method. [method] is expected to include its parenthesised
     * parameter list.
     */
    fun createMethod(clazz: String, method: String): Symbol {
      return MemberSymbol(toInternalFormat(clazz), toInternalFormat(method))
    }

    /** Replaces every character in [FORBIDDEN_CHARS] with '/'. */
    protected fun toInternalFormat(name: String): String =
        FORBIDDEN_CHARS.fold(name) { acc, ch -> acc.replace(ch, '/') }
  }

  /** Returns a human readable, '/'-delimited representation of this symbol. */
  abstract fun toPrettyString(): String
}

/** A class, together with its (already internal-format) superclass and implemented interfaces. */
internal data class ClassSymbol(
    val clazz: String,
    val superclass: String?,
    val interfaces: Set<String>
) : Symbol() {
  override fun toPrettyString(): String = clazz
}

/** A field or method [member] of the class [clazz]. */
internal data class MemberSymbol(val clazz: String, val member: String) : Symbol() {
  override fun toPrettyString(): String = "$clazz/$member"
}
101 
/**
 * Class representing the fully qualified name of an aconfig flag.
 *
 * This includes both the flag's package and name, separated by a dot, e.g.:
 * ```
 *   com.android.aconfig.test.disabled_ro
 * ```
 */
@JvmInline
internal value class Flag(val name: String) {
  /** Returns the fully qualified flag name as-is, so a [Flag] prints exactly like its raw name. */
  override fun toString(): String = name
}
114 
/** A problem detected for a single @FlaggedApi-annotated [symbol] guarded by [flag]. */
internal sealed class ApiError {
  abstract val symbol: Symbol
  abstract val flag: Flag
}

/** Reported when [flag] is enabled, yet [symbol] is absent from the built artifact. */
internal data class EnabledFlaggedApiNotPresentError(
    override val symbol: Symbol,
    override val flag: Flag
) : ApiError() {
  override fun toString(): String =
      "error: enabled @FlaggedApi not present in built artifact: symbol=${symbol.toPrettyString()} flag=$flag"
}

/** Reported when [flag] is disabled, yet [symbol] is present in the built artifact. */
internal data class DisabledFlaggedApiIsPresentError(
    override val symbol: Symbol,
    override val flag: Flag
) : ApiError() {
  override fun toString(): String =
      "error: disabled @FlaggedApi is present in built artifact: symbol=${symbol.toPrettyString()} flag=$flag"
}

/** Reported when the [flag] guarding [symbol] is unknown. */
internal data class UnknownFlagError(
    override val symbol: Symbol,
    override val flag: Flag
) : ApiError() {
  override fun toString(): String =
      "error: unknown flag: symbol=${symbol.toPrettyString()} flag=$flag"
}
144 
/** Name of the command-line option that points at the API signature file. */
val ARG_API_SIGNATURE: String = "--api-signature"

/** Help text shown for [ARG_API_SIGNATURE]. */
val ARG_API_SIGNATURE_HELP: String = """
Path to API signature file.
Usually named *current.txt.
Tip: `m frameworks-base-api-current.txt` will generate a file that includes all platform and mainline APIs.
"""

/** Name of the command-line option that points at the aconfig parsed_flags proto file. */
val ARG_FLAG_VALUES: String = "--flag-values"

/** Help text shown for [ARG_FLAG_VALUES]. */
val ARG_FLAG_VALUES_HELP: String = """
Path to aconfig parsed_flags binary proto file.
Tip: `m all_aconfig_declarations` will generate a file that includes all information about all flags.
"""

/** Name of the command-line option that points at the API versions XML file. */
val ARG_API_VERSIONS: String = "--api-versions"

/** Help text shown for [ARG_API_VERSIONS]. */
val ARG_API_VERSIONS_HELP: String = """
Path to API versions XML file.
Usually named xml-versions.xml.
Tip: `m sdk dist` will generate a file that includes all platform and mainline APIs.
"""
167 
/** Root command. It does no work of its own; subcommands are attached to it in [main]. */
class MainCommand : CliktCommand() {
  override fun run() = Unit
}
171 
/**
 * Cross-checks the flagged APIs of an API signature file against flag values and the APIs actually
 * present in an api-versions.xml file, printing one line per error found.
 */
class CheckCommand :
    CliktCommand(
        help =
            """
Check that all flagged APIs are used in the correct way.

This tool reads the API signature file and checks that all flagged APIs are used in the correct way.

The tool will exit with a non-zero exit code if any flagged APIs are found to be used in the incorrect way.
""") {
  private val apiSignaturePath by
      option(ARG_API_SIGNATURE)
          .help(ARG_API_SIGNATURE_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()
  private val flagValuesPath by
      option(ARG_FLAG_VALUES)
          .help(ARG_FLAG_VALUES_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()
  private val apiVersionsPath by
      option(ARG_API_VERSIONS)
          .help(ARG_API_VERSIONS_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()

  override fun run() {
    val flaggedSymbols =
        apiSignaturePath.toFile().inputStream().use {
          parseApiSignature(apiSignaturePath.toString(), it)
        }
    val flags = flagValuesPath.toFile().inputStream().use { parseFlagValues(it) }
    val exportedSymbols = apiVersionsPath.toFile().inputStream().use { parseApiVersions(it) }
    val errors = findErrors(flaggedSymbols, flags, exportedSymbols)
    for (e in errors) {
      println(e)
    }
    // Process exit statuses are truncated to their low 8 bits, so passing the raw error count
    // would turn e.g. exactly 256 errors into a "successful" status of 0. Clamp to the largest
    // status that survives truncation; the exact count is still visible from the printed errors.
    val maxExitStatus = 255
    throw ProgramResult(errors.size.coerceAtMost(maxExitStatus))
  }
}
212 
/**
 * Prints every flagged API in the API signature file together with its flag and the flag's state.
 * Prints nothing at all when no flagged APIs are found.
 */
class ListCommand :
    CliktCommand(
        help =
            """
List all flagged APIs and corresponding flags.

The output format is "<fully-qualified-name-of-flag> <state-of-flag> <API>", one line per API.

The output can be post-processed by e.g. piping it to grep to filter out only enabled APIs, or all APIs guarded by a given flag.
""") {
  private val apiSignaturePath by
      option(ARG_API_SIGNATURE)
          .help(ARG_API_SIGNATURE_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()
  private val flagValuesPath by
      option(ARG_FLAG_VALUES)
          .help(ARG_FLAG_VALUES_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()

  override fun run() {
    val symbolsWithFlags =
        apiSignaturePath.toFile().inputStream().use { stream ->
          parseApiSignature(apiSignaturePath.toString(), stream)
        }
    val flagValues = flagValuesPath.toFile().inputStream().use { stream -> parseFlagValues(stream) }
    val lines = listFlaggedApis(symbolsWithFlags, flagValues)
    // Emit nothing (not even an empty line) when there is nothing to list.
    if (lines.isEmpty()) {
      return
    }
    println(lines.joinToString("\n"))
  }
}
246 
/**
 * Parses an API signature file and collects every class, field and method annotated with
 * `@android.annotation.FlaggedApi`.
 *
 * @param path the signature file's path, handed to the metalava parser alongside [input]
 * @param input the signature file contents
 * @return each annotated API as a [Symbol], paired with the [Flag] named by its annotation's `value`
 */
internal fun parseApiSignature(path: String, input: InputStream): Set<Pair<Symbol, Flag>> {
  val flaggedApis = mutableSetOf<Pair<Symbol, Flag>>()

  // Returns the flag named by the item's @FlaggedApi annotation, or null when it is not annotated.
  fun flagOf(item: Item): Flag? =
      item.modifiers
          .findAnnotation("android.annotation.FlaggedApi")
          ?.findAttribute("value")
          ?.value
          ?.let { Flag(it.value() as String) }

  val collector =
      object : BaseItemVisitor() {
        override fun visitClass(cls: ClassItem) {
          val flag = flagOf(cls) ?: return
          val className = cls.baselineElementId()
          // Interfaces are recorded as extending java/lang/Object.
          val superclassName =
              if (cls.isInterface()) "java/lang/Object" else cls.superClass()?.baselineElementId()
          // Drop the class itself in case allInterfaces() reports it.
          val interfaceNames =
              cls.allInterfaces().map { it.baselineElementId() }.filter { it != className }.toSet()
          flaggedApis.add(Symbol.createClass(className, superclassName, interfaceNames) to flag)
        }

        override fun visitField(field: FieldItem) {
          val flag = flagOf(field) ?: return
          val owner = field.containingClass().baselineElementId()
          flaggedApis.add(Symbol.createField(owner, field.name()) to flag)
        }

        override fun visitMethod(method: MethodItem) {
          val flag = flagOf(method) ?: return
          // Parameter types are concatenated without separators, e.g. "foo(I[J)"; presumably
          // internalName() yields JVM-descriptor-style type names (TODO confirm).
          val name = method.name()
          val parameterTypes =
              method.parameters().joinToString(separator = "") { it.type().internalName() }
          val owner = method.containingClass().qualifiedName()
          flaggedApis.add(Symbol.createMethod(owner, "$name($parameterTypes)") to flag)
        }
      }

  ApiFile.parseApi(path, input).accept(collector)
  return flaggedApis
}
302 
/**
 * Parses an aconfig `parsed_flags` binary proto.
 *
 * @param input the serialized proto
 * @return every flag, keyed by "<package>.<name>", mapped to whether its state is ENABLED; when the
 *   same flag appears more than once, the last occurrence wins
 */
internal fun parseFlagValues(input: InputStream): Map<Flag, Boolean> =
    Aconfig.parsed_flags.parseFrom(input).getParsedFlagList().associate { parsedFlag ->
      val flag = Flag("${parsedFlag.getPackage()}.${parsedFlag.getName()}")
      val enabled = parsedFlag.getState() == Aconfig.flag_state.ENABLED
      flag to enabled
    }
309 
/**
 * Parses an api-versions.xml document into the set of [Symbol]s it declares.
 *
 * Every `<class>` element yields a class symbol carrying its `<extends>` superclass and
 * `<implements>` interfaces. Every `<field>` and `<method>` element yields a member symbol, qualified
 * by the `name` attribute of its parent element. Constructors (`<init>`) are renamed after their
 * (innermost) class.
 *
 * @param input the XML document to parse
 * @return all classes, fields and methods found in the document
 * @throws IllegalArgumentException if an element is missing a required `name` attribute
 * @throws Exception if a method signature is not of the form `name(args)ret`
 */
internal fun parseApiVersions(input: InputStream): Set<Symbol> {
  fun Node.getAttribute(name: String): String? = getAttributes()?.getNamedItem(name)?.getNodeValue()

  // Copies a DOM NodeList into a List so it can be walked with a plain for-loop. This also
  // sidesteps index loops, which would need `..<`/rangeUntil (poorly supported by ktfmt).
  fun org.w3c.dom.NodeList.nodes(): List<Node> = List(getLength()) { item(it) }

  val output = mutableSetOf<Symbol>()
  val document = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(input)

  for (cls in document.getElementsByTagName("class").nodes()) {
    val className =
        requireNotNull(cls.getAttribute("name")) {
          "Bad XML: <class> element without name attribute"
        }
    var superclass: String? = null
    val interfaces = mutableSetOf<String>()
    for (child in cls.getChildNodes().nodes()) {
      when (child.getNodeName()) {
        "extends" -> {
          superclass =
              requireNotNull(child.getAttribute("name")) {
                "Bad XML: <extends> element without name attribute"
              }
        }
        "implements" -> {
          val interfaceName =
              requireNotNull(child.getAttribute("name")) {
                "Bad XML: <implements> element without name attribute"
              }
          interfaces.add(interfaceName)
        }
      }
    }
    output.add(Symbol.createClass(className, superclass, interfaces))
  }

  for (field in document.getElementsByTagName("field").nodes()) {
    val fieldName =
        requireNotNull(field.getAttribute("name")) {
          "Bad XML: <field> element without name attribute"
        }
    val className =
        requireNotNull(field.getParentNode()?.getAttribute("name")) {
          "Bad XML: top level <field> element"
        }
    output.add(Symbol.createField(className, fieldName))
  }

  for (method in document.getElementsByTagName("method").nodes()) {
    val methodSignature =
        requireNotNull(method.getAttribute("name")) {
          "Bad XML: <method> element without name attribute"
        }
    // "foo(II)V" -> ["foo", "II", "V"]. Splitting on plain chars avoids compiling a Regex for
    // every one of the (many) <method> elements; empty parts are kept exactly as before.
    val methodSignatureParts = methodSignature.split('(', ')')
    if (methodSignatureParts.size != 3) {
      throw Exception("Bad XML: method signature '$methodSignature'")
    }
    val (rawMethodName, methodArgs, _) = methodSignatureParts
    val packageAndClassName =
        requireNotNull(method.getParentNode()?.getAttribute("name")) {
              "Bad XML: top level <method> element, or <class> element missing name attribute"
            }
            .replace("$", "/")
    // Constructors are listed as "<init>"; name them after the innermost class instead.
    val methodName =
        if (rawMethodName == "<init>") packageAndClassName.split("/").last() else rawMethodName
    output.add(Symbol.createMethod(packageAndClassName, "$methodName($methodArgs)"))
  }

  return output
}
391 
/**
 * Find errors in the given data.
 *
 * For every flagged symbol, an error is reported when:
 * - its flag (or the flag of its flagged containing class) is missing from [flags], or
 * - its flag is effectively enabled but the symbol is not present in [symbolsInOutput], or
 * - its flag is effectively disabled but the symbol is present in [symbolsInOutput].
 *
 * A member counts as present if it is listed explicitly, if it is listed under one of the direct
 * interfaces of its class, or (recursively) if it is present in its class' superclass.
 *
 * @param flaggedSymbolsInSource the set of symbols that are flagged in the source code
 * @param flags the set of flags and their values
 * @param symbolsInOutput the set of symbols that are present in the output
 * @return the set of errors found
 */
internal fun findErrors(
    flaggedSymbolsInSource: Set<Pair<Symbol, Flag>>,
    flags: Map<Flag, Boolean>,
    symbolsInOutput: Set<Symbol>
): Set<ApiError> {
  // Index the output's classes by name once, instead of linearly scanning all output symbols for
  // every inherited-member lookup. First occurrence wins, matching the semantics of a linear find.
  val outputClassesByName = HashMap<String, ClassSymbol>()
  for (symbol in symbolsInOutput) {
    if (symbol is ClassSymbol) {
      outputClassesByName.putIfAbsent(symbol.clazz, symbol)
    }
  }

  // Likewise index the flag of each flagged source symbol by its pretty name (first occurrence
  // wins), so a member can look up the flag of its enclosing class without a linear scan.
  val sourceFlagsByName = HashMap<String, Flag>()
  for ((symbol, flag) in flaggedSymbolsInSource) {
    sourceFlagsByName.putIfAbsent(symbol.toPrettyString(), flag)
  }

  // Returns whether the symbol is present in the output, either explicitly or, for members,
  // via the surrounding class' interfaces or superclass chain.
  fun isInOutput(symbol: Symbol): Boolean {
    // trivial case: the symbol is explicitly listed in api-versions.xml
    if (symbol in symbolsInOutput) {
      return true
    }

    // non-trivial case: the symbol could be part of the surrounding class'
    // super class or interfaces
    val member =
        when (symbol) {
          is ClassSymbol -> return false
          is MemberSymbol -> symbol
        }
    val clazz = outputClassesByName[member.clazz] ?: return false

    // createMethod is the same as createField, except it allows parentheses
    if (clazz.interfaces.any { Symbol.createMethod(it, member.member) in symbolsInOutput }) {
      return true
    }

    val superclass = clazz.superclass ?: return false
    return isInOutput(Symbol.createMethod(superclass, member.member))
  }

  /**
   * Returns whether the given flag is enabled for the given symbol, or null if a flag needed to
   * decide this is missing from [flags].
   *
   * A flagged member inside a flagged class is ignored (and the flag value considered disabled) if
   * the class' flag is disabled.
   */
  fun isFlagEnabledForSymbol(symbol: Symbol, flag: Flag): Boolean? {
    val flagValue = flags[flag] ?: return null
    return when (symbol) {
      is ClassSymbol -> flagValue
      is MemberSymbol -> {
        if (!flagValue) {
          false
        } else {
          // Special case: if the MemberSymbol's flag is enabled, but the outer
          // ClassSymbol's flag (if the class is flagged) is disabled, consider
          // the MemberSymbol's flag as disabled:
          //
          //   @FlaggedApi(this-flag-is-disabled) Clazz {
          //       @FlaggedApi(this-flag-is-enabled) method(); // The Clazz' flag "wins"
          //   }
          //
          // Note: the current implementation does not handle nested classes.
          val classFlag = sourceFlagsByName[symbol.clazz]
          if (classFlag == null) true else flags[classFlag]
        }
      }
    }
  }

  val errors = mutableSetOf<ApiError>()
  for ((symbol, flag) in flaggedSymbolsInSource) {
    when (isFlagEnabledForSymbol(symbol, flag)) {
      null -> errors.add(UnknownFlagError(symbol, flag))
      true -> {
        if (!isInOutput(symbol)) {
          errors.add(EnabledFlaggedApiNotPresentError(symbol, flag))
        }
      }
      false -> {
        if (isInOutput(symbol)) {
          errors.add(DisabledFlaggedApiIsPresentError(symbol, flag))
        }
      }
    }
  }
  return errors
}
496 
/**
 * Collect all known info about all @FlaggedApi annotated APIs.
 *
 * Each API is represented as one String of the form
 * ```
 *   <fully-qualified-name-of-flag> <state-of-flag> <API>
 * ```
 * where `<state-of-flag>` is `ENABLED`, `DISABLED`, or `UNKNOWN` (the flag is absent from [flags]).
 *
 * @param flaggedSymbolsInSource the set of symbols that are flagged in the source code
 * @param flags the set of flags and their values
 * @return one String per flagged API in the format described above, sorted alphabetically
 */
internal fun listFlaggedApis(
    flaggedSymbolsInSource: Set<Pair<Symbol, Flag>>,
    flags: Map<Flag, Boolean>
): List<String> =
    flaggedSymbolsInSource
        .map { (symbol, flag) ->
          val state =
              when (flags[flag]) {
                null -> "UNKNOWN"
                true -> "ENABLED"
                false -> "DISABLED"
              }
          "$flag $state ${symbol.toPrettyString()}"
        }
        .sorted()
527 
/** Program entry point: dispatches to the [CheckCommand] and [ListCommand] subcommands. */
fun main(args: Array<String>) {
  val rootCommand = MainCommand().subcommands(CheckCommand(), ListCommand())
  rootCommand.main(args)
}
529