• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
<lambda>null2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 @file:JvmName("Main")
17 
18 package com.android.checkflaggedapis
19 
20 import android.aconfig.Aconfig
21 import com.android.tools.metalava.model.BaseItemVisitor
22 import com.android.tools.metalava.model.CallableItem
23 import com.android.tools.metalava.model.ClassItem
24 import com.android.tools.metalava.model.FieldItem
25 import com.android.tools.metalava.model.Item
26 import com.android.tools.metalava.model.text.ApiFile
27 import com.github.ajalt.clikt.core.CliktCommand
28 import com.github.ajalt.clikt.core.ProgramResult
29 import com.github.ajalt.clikt.core.subcommands
30 import com.github.ajalt.clikt.parameters.options.help
31 import com.github.ajalt.clikt.parameters.options.option
32 import com.github.ajalt.clikt.parameters.options.required
33 import com.github.ajalt.clikt.parameters.types.path
34 import java.io.InputStream
35 import javax.xml.parsers.DocumentBuilderFactory
36 import org.w3c.dom.Node
37 
/**
 * The fully qualified name of a class, method or field.
 *
 * This tool reads several input formats, each of which spells the fully qualified path to a Java
 * symbol slightly differently. To keep comparisons consistent, every parsed API is converted to a
 * Symbol through one of the factory functions in the companion object.
 *
 * Symbols are encoded using a format similar to the one described in section 4.3.2 of the JVM
 * spec [1]. The format is simplified to make translating the various input formats easier: the
 * factory functions replace every '#', '$' and '.' with '/', so '/' is the only delimiter. For
 * instance, "package.class.inner-class.method(int, int[], android.util.Clazz)" is represented as
 * <pre>
 *   package/class/inner-class/method(II[Landroid/util/Clazz;)
 * </pre>
 *
 * 1. https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html#jvms-4.3.2
 */
internal sealed class Symbol {
  /** Returns the human readable form of this symbol, as used in error messages and listings. */
  abstract fun toPrettyString(): String

  companion object {
    // Delimiters that may show up in the input formats; all are normalized to '/'.
    private val FORBIDDEN_CHARS = listOf('#', '$', '.')

    /**
     * Creates a class symbol; the class name, the optional superclass name and every interface
     * name are converted to the internal format.
     */
    fun createClass(clazz: String, superclass: String?, interfaces: Set<String>): Symbol {
      val internalSuperclass = if (superclass == null) null else toInternalFormat(superclass)
      val internalInterfaces = interfaces.map { name -> toInternalFormat(name) }.toSet()
      return ClassSymbol(toInternalFormat(clazz), internalSuperclass, internalInterfaces)
    }

    /**
     * Creates a member symbol for a field.
     *
     * @throws IllegalArgumentException if [field] contains a parenthesis
     */
    fun createField(clazz: String, field: String): Symbol {
      require('(' !in field && ')' !in field)
      return MemberSymbol(toInternalFormat(clazz), toInternalFormat(field))
    }

    /** Creates a member symbol for a method; unlike [createField], parentheses are allowed. */
    fun createMethod(clazz: String, method: String): Symbol =
        MemberSymbol(toInternalFormat(clazz), toInternalFormat(method))

    /** Returns [name] with every character in [FORBIDDEN_CHARS] replaced by '/'. */
    protected fun toInternalFormat(name: String): String =
        FORBIDDEN_CHARS.fold(name) { normalized, ch -> normalized.replace(ch, '/') }
  }
}
89 
/**
 * Symbol naming a class, together with the names of its superclass (if any) and its interfaces.
 */
internal data class ClassSymbol(
    val clazz: String,
    val superclass: String?,
    val interfaces: Set<String>
) : Symbol() {
  /** The pretty form of a class symbol is just the class name. */
  override fun toPrettyString(): String {
    return clazz
  }
}
97 
/** Symbol naming a member (field or method) of the class [clazz]. */
internal data class MemberSymbol(val clazz: String, val member: String) : Symbol() {
  /** The pretty form is the class name and the member name joined by '/'. */
  override fun toPrettyString(): String {
    return clazz + "/" + member
  }
}
101 
/**
 * The fully qualified name of an aconfig flag.
 *
 * The name consists of the flag's package and the flag's name, separated by a dot, e.g.:
 * <pre>
 *   com.android.aconfig.test.disabled_ro
 * </pre>
 */
@JvmInline
internal value class Flag(val name: String) {
  /** Returns the fully qualified flag name unchanged. */
  override fun toString(): String = name
}
114 
/** Base type of the problems this tool reports; each error names a symbol and a flag. */
internal sealed class ApiError {
  /** The symbol the error is about. */
  abstract val symbol: Symbol

  /** The flag associated with [symbol]. */
  abstract val flag: Flag
}
119 
/** Error: the flag of a flagged API is enabled, but the API is missing from the built artifact. */
internal data class EnabledFlaggedApiNotPresentError(
    override val symbol: Symbol,
    override val flag: Flag
) : ApiError() {
  override fun toString(): String =
      "error: enabled @FlaggedApi not present in built artifact: symbol=${symbol.toPrettyString()} flag=$flag"
}
128 
/** Error: the flag of a flagged API is disabled, but the API is present in the built artifact. */
internal data class DisabledFlaggedApiIsPresentError(
    override val symbol: Symbol,
    override val flag: Flag
) : ApiError() {
  override fun toString(): String =
      "error: disabled @FlaggedApi is present in built artifact: symbol=${symbol.toPrettyString()} flag=$flag"
}
137 
/** Error: a flag needed to evaluate a flagged API is not among the known flag values. */
internal data class UnknownFlagError(
    override val symbol: Symbol,
    override val flag: Flag
) : ApiError() {
  override fun toString(): String =
      "error: unknown flag: symbol=${symbol.toPrettyString()} flag=$flag"
}
144 
/** Name of the command line option that takes the API signature file. */
val ARG_API_SIGNATURE: String = "--api-signature"

/** Help text shown for the [ARG_API_SIGNATURE] option. */
val ARG_API_SIGNATURE_HELP: String =
    """
Path to API signature file.
Usually named *current.txt.
Tip: `m frameworks-base-api-current.txt` will generate a file that includes all platform and mainline APIs.
"""

/** Name of the command line option that takes the aconfig flag values file. */
val ARG_FLAG_VALUES: String = "--flag-values"

/** Help text shown for the [ARG_FLAG_VALUES] option. */
val ARG_FLAG_VALUES_HELP: String =
    """
Path to aconfig parsed_flags binary proto file.
Tip: `m all_aconfig_declarations` will generate a file that includes all information about all flags.
"""

/** Name of the command line option that takes the API versions XML file. */
val ARG_API_VERSIONS: String = "--api-versions"

/** Help text shown for the [ARG_API_VERSIONS] option. */
val ARG_API_VERSIONS_HELP: String =
    """
Path to API versions XML file.
Usually named xml-versions.xml.
Tip: `m sdk dist` will generate a file that includes all platform and mainline APIs.
"""
167 
/** Root command; it does no work of its own and only hosts the subcommands registered in main. */
class MainCommand : CliktCommand() {
  override fun run() = Unit
}
171 
/**
 * Subcommand that cross-checks the flagged APIs in the API signature file against the flag values
 * and the symbols listed in the API versions file.
 *
 * Every error found is printed on its own line. The command always finishes by throwing
 * [ProgramResult]; the status code is the number of errors, capped at 255 (zero means no errors).
 */
class CheckCommand :
    CliktCommand(
        help =
            """
Check that all flagged APIs are used in the correct way.

This tool reads the API signature file and checks that all flagged APIs are used in the correct way.

The tool will exit with a non-zero exit code if any flagged APIs are found to be used in the incorrect way.
""") {
  private val apiSignaturePath by
      option(ARG_API_SIGNATURE)
          .help(ARG_API_SIGNATURE_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()
  private val flagValuesPath by
      option(ARG_FLAG_VALUES)
          .help(ARG_FLAG_VALUES_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()
  private val apiVersionsPath by
      option(ARG_API_VERSIONS)
          .help(ARG_API_VERSIONS_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()

  override fun run() {
    val flaggedSymbols =
        apiSignaturePath.toFile().inputStream().use {
          parseApiSignature(apiSignaturePath.toString(), it)
        }
    val flags = flagValuesPath.toFile().inputStream().use { parseFlagValues(it) }
    val exportedSymbols = apiVersionsPath.toFile().inputStream().use { parseApiVersions(it) }
    val errors = findErrors(flaggedSymbols, flags, exportedSymbols)
    for (e in errors) {
      println(e)
    }
    // A process exit status only keeps its low 8 bits on POSIX systems, so an unclamped error
    // count that is a multiple of 256 would be reported as 0, i.e. success. Cap the status so
    // that any number of errors still yields a non-zero exit code.
    throw ProgramResult(errors.size.coerceAtMost(MAX_EXIT_CODE))
  }

  private companion object {
    /** Largest exit status that survives the 8-bit truncation applied by POSIX systems. */
    const val MAX_EXIT_CODE = 255
  }
}
212 
/**
 * Subcommand that prints one line per flagged API found in the API signature file, combining the
 * API with its flag and the state of that flag. Nothing is printed when there are no flagged APIs.
 */
class ListCommand :
    CliktCommand(
        help =
            """
List all flagged APIs and corresponding flags.

The output format is "<fully-qualified-name-of-flag> <state-of-flag> <API>", one line per API.

The output can be post-processed by e.g. piping it to grep to filter out only enabled APIs, or all APIs guarded by a given flag.
""") {
  private val apiSignaturePath by
      option(ARG_API_SIGNATURE)
          .help(ARG_API_SIGNATURE_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()
  private val flagValuesPath by
      option(ARG_FLAG_VALUES)
          .help(ARG_FLAG_VALUES_HELP)
          .path(mustExist = true, canBeDir = false, mustBeReadable = true)
          .required()

  override fun run() {
    val signatureStream = apiSignaturePath.toFile().inputStream()
    val flaggedSymbols =
        signatureStream.use { stream -> parseApiSignature(apiSignaturePath.toString(), stream) }

    val flagValuesStream = flagValuesPath.toFile().inputStream()
    val flags = flagValuesStream.use { stream -> parseFlagValues(stream) }

    val lines = listFlaggedApis(flaggedSymbols, flags)
    if (lines.isEmpty()) {
      // Avoid printing a lone empty line when there is nothing to list.
      return
    }
    println(lines.joinToString("\n"))
  }
}
246 
/**
 * Parses an API signature file and collects every class, field and callable annotated with
 * `android.annotation.FlaggedApi`, paired with the flag named by the annotation's `value`.
 *
 * @param path the path handed to the signature parser together with [input]
 * @param input the contents of the API signature file
 * @return the flagged symbols found, each paired with its flag
 */
internal fun parseApiSignature(path: String, input: InputStream): Set<Pair<Symbol, Flag>> {
  val flaggedSymbols = mutableSetOf<Pair<Symbol, Flag>>()
  val collector =
      object : BaseItemVisitor() {
        /** Returns the flag of the item's FlaggedApi annotation, or null if there is none. */
        private fun flagOf(item: Item): Flag? {
          val annotationValue =
              item.modifiers
                  .findAnnotation("android.annotation.FlaggedApi")
                  ?.findAttribute("value")
                  ?.legacyValue ?: return null
          return Flag(annotationValue.value() as String)
        }

        override fun visitClass(cls: ClassItem) {
          val flag = flagOf(cls) ?: return
          val classId = cls.baselineElementId()
          // Interfaces are recorded with java/lang/Object as their superclass.
          val superclassId =
              if (cls.isInterface()) "java/lang/Object" else cls.superClass()?.baselineElementId()
          // The class itself is excluded from its own interface set.
          val interfaceIds =
              cls.allInterfaces()
                  .map { iface -> iface.baselineElementId() }
                  .filter { id -> id != classId }
                  .toSet()
          flaggedSymbols.add(Pair(Symbol.createClass(classId, superclassId, interfaceIds), flag))
        }

        override fun visitField(field: FieldItem) {
          val flag = flagOf(field) ?: return
          val ownerId = field.containingClass().baselineElementId()
          flaggedSymbols.add(Pair(Symbol.createField(ownerId, field.name()), flag))
        }

        override fun visitCallable(callable: CallableItem) {
          val flag = flagOf(callable) ?: return
          // Signature format: name(<internal names of the parameter types, concatenated>)
          val parameterTypes =
              callable.parameters().joinToString(separator = "", prefix = "(", postfix = ")") {
                it.type().internalName()
              }
          val signature = callable.name() + parameterTypes
          val ownerName = callable.containingClass().qualifiedName()
          flaggedSymbols.add(Pair(Symbol.createMethod(ownerName, signature), flag))
        }
      }
  ApiFile.parseApi(path, input).accept(collector)
  return flaggedSymbols
}
303 
/**
 * Parses an aconfig `parsed_flags` binary proto.
 *
 * @param input the serialized `parsed_flags` message
 * @return a map from each flag ("<package>.<name>") to whether its state is ENABLED
 */
internal fun parseFlagValues(input: InputStream): Map<Flag, Boolean> {
  val parsedFlags = Aconfig.parsed_flags.parseFrom(input).getParsedFlagList()
  return parsedFlags.associate { parsedFlag ->
    val flag = Flag("${parsedFlag.getPackage()}.${parsedFlag.getName()}")
    val isEnabled = parsedFlag.getState() == Aconfig.flag_state.ENABLED
    flag to isEnabled
  }
}
310 
/**
 * Parses an API versions XML document and returns a symbol for every `<class>`, `<field>` and
 * `<method>` element found anywhere in it.
 *
 * Class symbols carry the names from their `<extends>` and `<implements>` children. A method named
 * `<init>` is recorded under the last '/'-separated segment of its class name instead.
 *
 * @param input the XML document
 * @return the symbols listed in the document
 * @throws IllegalArgumentException if a required `name` attribute is missing
 * @throws Exception if a method name is not on the form `name(args)rest`
 */
internal fun parseApiVersions(input: InputStream): Set<Symbol> {
  fun Node.attr(name: String): String? = attributes?.getNamedItem(name)?.nodeValue

  // Snapshot of a DOM node list as a regular Kotlin list, in document order.
  fun org.w3c.dom.NodeList.toNodes(): List<Node> = List(length) { index -> item(index) }

  val document = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(input)
  val symbols = mutableSetOf<Symbol>()

  for (classNode in document.getElementsByTagName("class").toNodes()) {
    val className =
        requireNotNull(classNode.attr("name")) { "Bad XML: <class> element without name attribute" }
    var superclassName: String? = null
    val interfaceNames = mutableSetOf<String>()
    for (child in classNode.childNodes.toNodes()) {
      when (child.nodeName) {
        "extends" ->
            superclassName =
                requireNotNull(child.attr("name")) {
                  "Bad XML: <extends> element without name attribute"
                }
        "implements" ->
            interfaceNames.add(
                requireNotNull(child.attr("name")) {
                  "Bad XML: <implements> element without name attribute"
                })
      }
    }
    symbols.add(Symbol.createClass(className, superclassName, interfaceNames))
  }

  for (fieldNode in document.getElementsByTagName("field").toNodes()) {
    val fieldName =
        requireNotNull(fieldNode.attr("name")) { "Bad XML: <field> element without name attribute" }
    val ownerName =
        requireNotNull(fieldNode.parentNode?.attr("name")) { "Bad XML: top level <field> element" }
    symbols.add(Symbol.createField(ownerName, fieldName))
  }

  // Splits "name(args)rest" into exactly three parts: name, args and rest.
  val parenthesis = Regex("\\(|\\)")
  for (methodNode in document.getElementsByTagName("method").toNodes()) {
    val signature =
        requireNotNull(methodNode.attr("name")) {
          "Bad XML: <method> element without name attribute"
        }
    val parts = signature.split(parenthesis)
    if (parts.size != 3) {
      throw Exception("Bad XML: method signature '$signature'")
    }
    val ownerName =
        requireNotNull(methodNode.parentNode?.attr("name")) {
              "Bad XML: top level <method> element, or <class> element missing name attribute"
            }
            .replace("$", "/")
    // Constructors are named after the innermost class rather than "<init>".
    val methodName = if (parts[0] == "<init>") ownerName.split("/").last() else parts[0]
    val methodArgs = parts[1]
    symbols.add(Symbol.createMethod(ownerName, "$methodName($methodArgs)"))
  }

  return symbols
}
392 
/**
 * Find errors in the given data.
 *
 * For every flagged symbol the effective flag state is compared with the symbol's presence in the
 * output: an enabled symbol that is missing, a disabled symbol that is present, and a symbol whose
 * evaluation needs a flag that is not in [flags] are each reported as an error.
 *
 * @param flaggedSymbolsInSource the set of symbols that are flagged in the source code
 * @param flags the set of flags and their values
 * @param symbolsInOutput the set of symbols that are present in the output
 * @return the set of errors found
 */
internal fun findErrors(
    flaggedSymbolsInSource: Set<Pair<Symbol, Flag>>,
    flags: Map<Flag, Boolean>,
    symbolsInOutput: Set<Symbol>
): Set<ApiError> {
  // Classes present in the output, indexed by class name. Built once so that member lookups do
  // not have to scan the whole output set; the first class with a given name wins.
  val outputClassesByName = mutableMapOf<String, ClassSymbol>()
  for (symbol in symbolsInOutput) {
    if (symbol is ClassSymbol) {
      outputClassesByName.getOrPut(symbol.clazz) { symbol }
    }
  }

  // Flags of the flagged classes, indexed by the class' pretty name; the first entry wins. Only
  // ClassSymbols are indexed: a flagged member whose pretty string happens to equal the name of
  // some class must not be mistaken for that class.
  val classFlagsByName = mutableMapOf<String, Flag>()
  for ((symbol, flag) in flaggedSymbolsInSource) {
    if (symbol is ClassSymbol) {
      classFlagsByName.getOrPut(symbol.toPrettyString()) { flag }
    }
  }

  /**
   * Returns whether the symbol is present in the output, either listed explicitly or, for
   * members, listed on one of the interfaces or on a superclass of the surrounding class.
   */
  fun isPresentInOutput(symbol: Symbol): Boolean {
    // trivial case: the symbol is explicitly listed in api-versions.xml
    if (symbol in symbolsInOutput) {
      return true
    }

    // non-trivial case: the symbol could be part of the surrounding class'
    // super class or interfaces
    val member =
        when (symbol) {
          is ClassSymbol -> return false
          is MemberSymbol -> symbol
        }
    val clazz = outputClassesByName[member.clazz] ?: return false

    for (interfaceName in clazz.interfaces) {
      // createMethod is the same as createField, except it allows parenthesis
      if (Symbol.createMethod(interfaceName, member.member) in symbolsInOutput) {
        return true
      }
    }

    // Walk up the superclass chain; interfaces are only checked for an explicit listing.
    val superclass = clazz.superclass ?: return false
    return isPresentInOutput(Symbol.createMethod(superclass, member.member))
  }

  /**
   * Returns whether the given flag is enabled for the given symbol.
   *
   * A flagged member inside a flagged class is ignored (and the flag value considered disabled) if
   * the class' flag is disabled.
   *
   * @param symbol the symbol to check
   * @param flag the flag to check
   * @return whether the flag is enabled for the given symbol
   * @throws NoSuchElementException if a flag that has to be consulted is not in [flags]
   */
  fun isFlagEnabledForSymbol(symbol: Symbol, flag: Flag): Boolean {
    val ownFlagValue = flags.getValue(flag)
    return when (symbol) {
      is ClassSymbol -> ownFlagValue
      is MemberSymbol -> {
        if (!ownFlagValue) {
          false
        } else {
          // Special case: if the MemberSymbol's flag is enabled, but the outer
          // ClassSymbol's flag (if the class is flagged) is disabled, consider
          // the MemberSymbol's flag as disabled:
          //
          //   @FlaggedApi(this-flag-is-disabled) Clazz {
          //       @FlaggedApi(this-flag-is-enabled) method(); // The Clazz' flag "wins"
          //   }
          //
          // Note: the current implementation does not handle nested classes.
          classFlagsByName[symbol.clazz]?.let { classFlag -> flags.getValue(classFlag) } ?: true
        }
      }
    }
  }

  val errors = mutableSetOf<ApiError>()
  for ((symbol, flag) in flaggedSymbolsInSource) {
    try {
      val enabled = isFlagEnabledForSymbol(symbol, flag)
      val present = isPresentInOutput(symbol)
      if (enabled && !present) {
        errors.add(EnabledFlaggedApiNotPresentError(symbol, flag))
      } else if (!enabled && present) {
        errors.add(DisabledFlaggedApiIsPresentError(symbol, flag))
      }
    } catch (e: NoSuchElementException) {
      // Thrown by flags.getValue when a flag that had to be consulted is unknown.
      errors.add(UnknownFlagError(symbol, flag))
    }
  }
  return errors
}
496 
/**
 * Collect all known info about all @FlaggedApi annotated APIs.
 *
 * Each API is represented as a String on the format
 * <pre>
 *   &lt;fully-qualified-name-of-flag&gt; &lt;state-of-flag&gt; &lt;API&gt;
 * </pre>
 * where the state is ENABLED, DISABLED, or UNKNOWN if the flag is not present in [flags].
 *
 * @param flaggedSymbolsInSource the set of symbols that are flagged in the source code
 * @param flags the set of flags and their values
 * @return a list of Strings encoding API data using the format described above, sorted
 *   alphabetically
 */
internal fun listFlaggedApis(
    flaggedSymbolsInSource: Set<Pair<Symbol, Flag>>,
    flags: Map<Flag, Boolean>
): List<String> {
  return flaggedSymbolsInSource
      .map { (symbol, flag) ->
        val flagState =
            flags[flag]?.let { enabled -> if (enabled) "ENABLED" else "DISABLED" } ?: "UNKNOWN"
        "$flag $flagState ${symbol.toPrettyString()}"
      }
      .sorted()
}
527 
/** Program entry point: runs [MainCommand] with the check and list subcommands registered. */
fun main(args: Array<String>) {
  val rootCommand = MainCommand().subcommands(CheckCommand(), ListCommand())
  rootCommand.main(args)
}
529