1 /*
<lambda>null2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.room.parser
18 
19 import androidx.room.ColumnInfo
20 import androidx.room.compiler.codegen.XTypeName
21 import androidx.room.compiler.processing.XProcessingEnv
22 import androidx.room.compiler.processing.XType
23 import androidx.room.ext.CommonTypeNames
24 import androidx.room.parser.expansion.isCoreSelect
25 import java.util.Locale
26 import org.antlr.v4.runtime.tree.ParseTree
27 import org.antlr.v4.runtime.tree.TerminalNode
28 
/**
 * Walks the parse tree of a single SQL statement and collects what Room needs to know about it:
 * the statement's [QueryType], its bind parameters, the tables it reads from, and whether a star
 * (`*` or `table.*`) result column was found.
 *
 * The tree is visited eagerly from `init`; call [createParsedQuery] afterwards to get the result.
 *
 * @param original the SQL text, copied as-is into the resulting [ParsedQuery].
 * @param syntaxErrors syntax errors reported while parsing [statement], copied into the result.
 * @param statement the parse tree of the statement to analyze.
 */
class QueryVisitor(
    private val original: String,
    private val syntaxErrors: List<String>,
    statement: ParseTree
) : SQLiteParserBaseVisitor<Void?>() {
    // Every bind parameter found in the statement, in visit order (sorted later by position).
    private val bindingExpressions = arrayListOf<BindParameterNode>()
    // Tables referenced by the statement, as (unescaped name, unescaped alias) pairs.
    private val tableNames = mutableSetOf<Table>()
    // Names declared by common table expressions (WITH clauses), normalized with cteKey(). They
    // are not real tables and must not end up in tableNames.
    private val withClauseNames = mutableSetOf<String>()
    private val queryType: QueryType
    private var foundTopLevelStarProjection: Boolean = false

    init {
        // The statement type comes from the first direct child that identifies one; for
        // `EXPLAIN SELECT ...` the EXPLAIN keyword precedes the select statement.
        queryType =
            (0 until statement.childCount)
                .asSequence()
                .map { findQueryType(statement.getChild(it)) }
                .firstOrNull { it != QueryType.UNKNOWN } ?: QueryType.UNKNOWN
        statement.accept(this)
    }

    /** Maps one direct child of the statement node to the [QueryType] it denotes, if any. */
    private fun findQueryType(statement: ParseTree): QueryType {
        return when (statement) {
            is SQLiteParser.Select_stmtContext -> QueryType.SELECT
            is SQLiteParser.Delete_stmt_limitedContext,
            is SQLiteParser.Delete_stmtContext -> QueryType.DELETE
            is SQLiteParser.Insert_stmtContext -> QueryType.INSERT
            is SQLiteParser.Update_stmtContext,
            is SQLiteParser.Update_stmt_limitedContext -> QueryType.UPDATE
            // SQL keywords are case-insensitive and the token text keeps the casing used in the
            // query, so `explain select ...` must be detected just like `EXPLAIN SELECT ...`.
            is TerminalNode ->
                if (statement.text.equals("EXPLAIN", ignoreCase = true)) {
                    QueryType.EXPLAIN
                } else {
                    QueryType.UNKNOWN
                }
            else -> QueryType.UNKNOWN
        }
    }

    /** Records every bind parameter and whether it may receive multiple values. */
    override fun visitExpr(ctx: SQLiteParser.ExprContext): Void? {
        val bindParameter = ctx.BIND_PARAMETER()
        if (bindParameter != null) {
            val parentContext = ctx.parent
            // A parameter inside a comma separated list accepts multiple values, unless that list
            // is the argument list of a function with a fixed number of parameters.
            val isMultiple =
                parentContext is SQLiteParser.Comma_separated_exprContext &&
                    !isFixedParamFunctionExpr(parentContext)
            bindingExpressions.add(BindParameterNode(node = bindParameter, isMultiple = isMultiple))
        }
        return super.visitExpr(ctx)
    }

    /** Flags the query as having a star projection when a `*` / `table.*` result column is seen. */
    override fun visitResult_column(ctx: SQLiteParser.Result_columnContext): Void? {
        // NOTE(review): `&&` binds tighter than `||`, so a qualified `table.*` column sets the flag
        // regardless of `isCoreSelect`; only a bare `*` is gated by it. The parentheses spell out
        // the existing behavior — confirm that this is intended.
        if ((ctx.parent.isCoreSelect && ctx.text == "*") || ctx.text.endsWith(".*")) {
            foundTopLevelStarProjection = true
        }
        return super.visitResult_column(ctx)
    }

    /**
     * Check if a comma separated expression (where multiple binding parameters are accepted) is
     * part of a function expression that receives a fixed number of parameters. This is important
     * for determining the priority of type converters used when binding a collection into a binding
     * parameter: if the function takes a fixed number of parameters, the collection should not be
     * expanded.
     */
    private fun isFixedParamFunctionExpr(ctx: SQLiteParser.Comma_separated_exprContext): Boolean {
        val functionName =
            (ctx.parent as? SQLiteParser.ExprContext)?.function_name() ?: return false
        return functionName.text.lowercase(Locale.US) in fixedParamFunctions
    }

    /**
     * Builds the [ParsedQuery] for the visited statement. Inputs are ordered by their position in
     * the query text; `hasTopStarProjection` is only reported (non-null) for SELECT statements.
     */
    fun createParsedQuery(): ParsedQuery {
        return ParsedQuery(
            original = original,
            type = queryType,
            inputs = bindingExpressions.sortedBy { it.sourceInterval.a },
            tables = tableNames,
            hasTopStarProjection =
                if (queryType == QueryType.SELECT) foundTopLevelStarProjection else null,
            syntaxErrors = syntaxErrors,
        )
    }

    override fun visitCommon_table_expression(
        ctx: SQLiteParser.Common_table_expressionContext
    ): Void? {
        // The name is recorded before the CTE body is visited, so a recursive CTE that refers to
        // itself is not reported as a table.
        val tableName = ctx.table_name()?.text
        if (tableName != null) {
            withClauseNames.add(cteKey(tableName))
        }
        return super.visitCommon_table_expression(ctx)
    }

    override fun visitTable_or_subquery(ctx: SQLiteParser.Table_or_subqueryContext): Void? {
        val tableName = ctx.table_name()?.text
        // Compare in normalized form: withClauseNames holds unescaped, lower-cased names, so a
        // quoted or differently-cased reference to a CTE (e.g. `cte` vs "cte") is still skipped.
        if (tableName != null && cteKey(tableName) !in withClauseNames) {
            val name = unescapeIdentifier(tableName)
            val alias = ctx.table_alias()?.text?.let { unescapeIdentifier(it) } ?: name
            tableNames.add(Table(name, alias))
        }
        return super.visitTable_or_subquery(ctx)
    }

    /**
     * Normalizes a CTE name for comparisons: quotes are stripped and the name is lower-cased, since
     * SQLite compares identifiers case-insensitively.
     */
    private fun cteKey(name: String): String = unescapeIdentifier(name).lowercase(Locale.US)

    /**
     * Trims [text] and repeatedly strips one matching pair of surrounding quote characters (see
     * [ESCAPE_LITERALS]). A lone quote character is returned unchanged instead of being sliced.
     */
    private fun unescapeIdentifier(text: String): String {
        val trimmed = text.trim()
        val quote =
            ESCAPE_LITERALS.firstOrNull {
                trimmed.length >= 2 && trimmed.startsWith(it) && trimmed.endsWith(it)
            }
        return if (quote != null) {
            unescapeIdentifier(trimmed.substring(1, trimmed.length - 1))
        } else {
            trimmed
        }
    }

    companion object {
        private val ESCAPE_LITERALS = listOf("\"", "'", "`")

        // List of built-in SQLite functions that take a fixed non-zero number of parameters
        // See: https://sqlite.org/lang_corefunc.html
        val fixedParamFunctions =
            setOf(
                "abs",
                "glob",
                "hex",
                "ifnull",
                "iif",
                "instr",
                "length",
                "like",
                "likelihood",
                "likely",
                "load_extension",
                "lower",
                "ltrim",
                "nullif",
                "quote",
                "randomblob",
                "replace",
                "round",
                "rtrim",
                "soundex",
                "sqlite_compileoption_get",
                "sqlite_compileoption_used",
                "sqlite_offset",
                "substr",
                "trim",
                "typeof",
                "unicode",
                "unlikely",
                "upper",
                "zeroblob"
            )
    }
}
191 
/** Entry points for turning Room query strings into [ParsedQuery] objects. */
class SqlParser {
    companion object {
        // Quote characters that must not appear anywhere inside an identifier.
        private val INVALID_IDENTIFIER_CHARS = arrayOf('`', '\"')

        /**
         * Parses [input] as a single SQL statement.
         *
         * A successfully parsed statement is analyzed by [QueryVisitor]. On the fallback path of
         * `SingleQuerySqlParser` (presumably when no statement could be parsed) the result is an
         * [QueryType.UNKNOWN] query with no inputs or tables that only carries the syntax errors.
         */
        fun parse(input: String): ParsedQuery {
            return SingleQuerySqlParser.parse(
                input = input,
                visit = { statement, syntaxErrors ->
                    val visitor =
                        QueryVisitor(
                            original = input,
                            syntaxErrors = syntaxErrors,
                            statement = statement,
                        )
                    visitor.createParsedQuery()
                },
                fallback = { syntaxErrors ->
                    ParsedQuery(
                        original = input,
                        type = QueryType.UNKNOWN,
                        inputs = emptyList(),
                        tables = emptySet(),
                        hasTopStarProjection = null,
                        syntaxErrors = syntaxErrors,
                    )
                },
            )
        }

        /** Returns `true` when [input] is not blank and contains neither a backtick nor a `"`. */
        fun isValidIdentifier(input: String): Boolean {
            if (input.isBlank()) {
                return false
            }
            return INVALID_IDENTIFIER_CHARS.none { forbidden -> forbidden in input }
        }

        /**
         * Creates a placeholder [ParsedQuery] for raw queries. It records only [tableNames] (each
         * table aliased to itself); its type is [QueryType.UNKNOWN] and it has no inputs.
         */
        fun rawQueryForTables(tableNames: Set<String>): ParsedQuery {
            val tables = tableNames.map { name -> Table(name = name, alias = name) }.toSet()
            return ParsedQuery(
                original = "raw query",
                type = QueryType.UNKNOWN,
                inputs = emptyList(),
                tables = tables,
                hasTopStarProjection = null,
                syntaxErrors = emptyList(),
            )
        }
    }
}
235 
/**
 * A bind parameter token found in a query. All [TerminalNode] behavior is delegated to [node].
 *
 * @property isMultiple `true` if this is a multi-param node, i.e. the parameter may be bound to
 *   multiple values (as decided by `QueryVisitor` from the surrounding expression).
 */
data class BindParameterNode(
    private val node: TerminalNode,
    val isMultiple: Boolean,
) : TerminalNode by node
240 
/** The kind of SQL statement a query was detected to be. */
enum class QueryType {
    UNKNOWN,
    SELECT,
    DELETE,
    UPDATE,
    EXPLAIN,
    INSERT;

    companion object {
        /**
         * Statement kinds that are accepted for queries.
         *
         * IF you change this, don't forget to update @Query documentation.
         */
        val SUPPORTED: HashSet<QueryType> = hashSetOf(SELECT, DELETE, UPDATE, INSERT)
    }
}
254 
/** SQLite type affinities Room distinguishes for columns. */
enum class SQLTypeAffinity {
    NULL,
    TEXT,
    INTEGER,
    REAL,
    BLOB;

    /**
     * Returns the types Room associates with this affinity, including their boxed (javac) or
     * nullable (KSP) variants:
     * - [TEXT]: String
     * - [INTEGER]: int, byte, char, long, short
     * - [REAL]: double, float
     * - [BLOB]: byte array
     * - [NULL]: `null` (no associated types)
     */
    fun getTypeMirrors(env: XProcessingEnv): List<XType>? {
        // Exhaustive on purpose (no `else`): adding an affinity must fail compilation here.
        return when (this) {
            TEXT -> withBoxedAndNullableTypes(env, CommonTypeNames.STRING)
            INTEGER ->
                withBoxedAndNullableTypes(
                    env,
                    XTypeName.PRIMITIVE_INT,
                    XTypeName.PRIMITIVE_BYTE,
                    XTypeName.PRIMITIVE_CHAR,
                    XTypeName.PRIMITIVE_LONG,
                    XTypeName.PRIMITIVE_SHORT
                )
            REAL ->
                withBoxedAndNullableTypes(
                    env,
                    XTypeName.PRIMITIVE_DOUBLE,
                    XTypeName.PRIMITIVE_FLOAT
                )
            BLOB -> withBoxedAndNullableTypes(env, XTypeName.getArrayName(XTypeName.PRIMITIVE_BYTE))
            NULL -> null
        }
    }

    /**
     * Produces the acceptable variations of the given type names. For each name, the type itself
     * is emitted, followed by:
     * - with KSP: its nullable version;
     * - otherwise (javac): its boxed version, if the name is primitive.
     */
    private fun withBoxedAndNullableTypes(
        env: XProcessingEnv,
        vararg typeNames: XTypeName
    ): List<XType> {
        return typeNames
            .flatMap { typeName ->
                sequence {
                    val type = env.requireType(typeName)
                    yield(type)
                    if (env.backend == XProcessingEnv.Backend.KSP) {
                        yield(type.makeNullable())
                    } else if (typeName.isPrimitive) {
                        yield(type.boxed())
                    }
                }
            }
            .toList()
    }

    companion object {
        /**
         * Maps a `ColumnInfo` affinity constant to [SQLTypeAffinity]; any other value (including
         * `null`) yields `null`.
         */
        fun fromAnnotationValue(value: Int?): SQLTypeAffinity? {
            return when (value) {
                ColumnInfo.BLOB -> BLOB
                ColumnInfo.INTEGER -> INTEGER
                ColumnInfo.REAL -> REAL
                ColumnInfo.TEXT -> TEXT
                else -> null
            }
        }
    }
}
321 
/** Collating sequences that can be selected for a column through `ColumnInfo` constants. */
enum class Collate {
    BINARY,
    NOCASE,
    RTRIM,
    LOCALIZED,
    UNICODE;

    companion object {
        /**
         * Maps a `ColumnInfo` collation constant to [Collate]; any other value (including `null`)
         * yields `null`.
         */
        fun fromAnnotationValue(value: Int?): Collate? =
            when (value) {
                ColumnInfo.BINARY -> BINARY
                ColumnInfo.NOCASE -> NOCASE
                ColumnInfo.RTRIM -> RTRIM
                ColumnInfo.LOCALIZED -> LOCALIZED
                ColumnInfo.UNICODE -> UNICODE
                else -> null
            }
    }
}
342 
/** SQLite full-text search module versions (FTS3 and FTS4). */
enum class FtsVersion { FTS3, FTS4 }
347