• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.arch.persistence.room.parser
18 
19 import android.arch.persistence.room.ColumnInfo
20 import org.antlr.v4.runtime.ANTLRInputStream
21 import org.antlr.v4.runtime.BaseErrorListener
22 import org.antlr.v4.runtime.CommonTokenStream
23 import org.antlr.v4.runtime.RecognitionException
24 import org.antlr.v4.runtime.Recognizer
25 import org.antlr.v4.runtime.tree.ParseTree
26 import org.antlr.v4.runtime.tree.TerminalNode
27 import java.util.ArrayList
28 import javax.annotation.processing.ProcessingEnvironment
29 import javax.lang.model.type.TypeKind
30 import javax.lang.model.type.TypeMirror
31 
/**
 * Walks one parsed SQL statement and collects what Room needs from it: the statement's
 * [QueryType], every bind-parameter expression, and every referenced table (with its alias),
 * excluding names that are declared by a `WITH` clause.
 *
 * All collection happens eagerly in the constructor; call [createParsedQuery] afterwards to
 * obtain the result.
 *
 * @param original the raw SQL text, passed through unchanged to the resulting [ParsedQuery].
 * @param syntaxErrors errors already collected while parsing; passed through to [ParsedQuery].
 * @param statement the parse tree of the statement to analyze.
 */
class QueryVisitor(val original: String, val syntaxErrors: ArrayList<String>,
                   statement: ParseTree) : SQLiteBaseVisitor<Void?>() {
    // Terminal nodes of every BIND_PARAMETER found inside an expression, in visit order.
    val bindingExpressions = arrayListOf<TerminalNode>()
    // table name alias mappings
    val tableNames = mutableSetOf<Table>()
    // Unescaped names declared by common table expressions (WITH name AS ...); references to
    // these are not real tables and are kept out of tableNames.
    val withClauseNames = mutableSetOf<String>()
    val queryType: QueryType

    init {
        // The first direct child of the statement that identifies a statement kind wins, so a
        // leading EXPLAIN keyword takes precedence over the statement that follows it.
        queryType = (0 until statement.childCount)
                .asSequence()
                .map { findQueryType(statement.getChild(it)) }
                .firstOrNull { it != QueryType.UNKNOWN } ?: QueryType.UNKNOWN

        statement.accept(this)
    }

    /**
     * Maps one direct child of the analyzed statement to a [QueryType], returning
     * [QueryType.UNKNOWN] for children that do not identify the statement kind.
     */
    private fun findQueryType(statement: ParseTree): QueryType {
        return when (statement) {
            is SQLiteParser.Factored_select_stmtContext,
            is SQLiteParser.Compound_select_stmtContext,
            is SQLiteParser.Select_stmtContext,
            is SQLiteParser.Simple_select_stmtContext ->
                QueryType.SELECT

            is SQLiteParser.Delete_stmt_limitedContext,
            is SQLiteParser.Delete_stmtContext ->
                QueryType.DELETE

            is SQLiteParser.Insert_stmtContext ->
                QueryType.INSERT
            is SQLiteParser.Update_stmtContext,
            is SQLiteParser.Update_stmt_limitedContext ->
                QueryType.UPDATE
            // SQL keywords are case-insensitive; an exact "EXPLAIN" match would miss
            // `explain ...` and let a later child decide the query type instead.
            is TerminalNode -> when {
                statement.text.equals("EXPLAIN", ignoreCase = true) -> QueryType.EXPLAIN
                else -> QueryType.UNKNOWN
            }
            else -> QueryType.UNKNOWN
        }
    }

    /** Records the bind parameter (if any) of each expression, then keeps descending. */
    override fun visitExpr(ctx: SQLiteParser.ExprContext): Void? {
        val bindParameter = ctx.BIND_PARAMETER()
        if (bindParameter != null) {
            bindingExpressions.add(bindParameter)
        }
        return super.visitExpr(ctx)
    }

    /**
     * Builds the [ParsedQuery] from the collected data. Bind parameters are ordered by their
     * position in the token stream.
     */
    fun createParsedQuery(): ParsedQuery {
        return ParsedQuery(original,
                queryType,
                bindingExpressions.sortedBy { it.sourceInterval.a },
                tableNames,
                syntaxErrors)
    }

    /** Remembers the (unescaped) name declared by a common table expression. */
    override fun visitCommon_table_expression(ctx: SQLiteParser.Common_table_expressionContext): Void? {
        val tableName = ctx.table_name()?.text
        if (tableName != null) {
            withClauseNames.add(unescapeIdentifier(tableName))
        }
        return super.visitCommon_table_expression(ctx)
    }

    /**
     * Records a referenced table and its alias (the table's own name when no alias is given),
     * unless the name refers to a common table expression.
     */
    override fun visitTable_or_subquery(ctx: SQLiteParser.Table_or_subqueryContext): Void? {
        val tableName = ctx.table_name()?.text
        if (tableName != null) {
            val unescapedName = unescapeIdentifier(tableName)
            // withClauseNames stores unescaped names, so the lookup must use the unescaped
            // form too; otherwise a quoted reference to a CTE would be reported as a table.
            if (unescapedName !in withClauseNames) {
                val tableAlias = ctx.table_alias()?.text
                tableNames.add(Table(unescapedName,
                                     unescapeIdentifier(tableAlias ?: tableName)))
            }
        }
        return super.visitTable_or_subquery(ctx)
    }

    /**
     * Trims [text] and strips one layer of SQLite identifier quoting — `` `name` ``,
     * `"name"`, `'name'` or `[name]` — repeating until no quoting remains.
     */
    private fun unescapeIdentifier(text: String): String {
        val trimmed = text.trim()
        // Quoting needs distinct opening and closing characters; this also keeps a lone quote
        // character from producing an invalid substring range.
        if (trimmed.length < 2) {
            return trimmed
        }
        val closingQuote = when (trimmed.first()) {
            '`' -> '`'
            '"' -> '"'
            '\'' -> '\''
            '[' -> ']'
            else -> return trimmed
        }
        return if (trimmed.last() == closingQuote) {
            unescapeIdentifier(trimmed.substring(1, trimmed.length - 1))
        } else {
            trimmed
        }
    }
}
120 
/** Entry point for parsing SQL with the ANTLR-generated SQLite grammar. */
class SqlParser {
    companion object {
        /**
         * Parses [input] and returns a [ParsedQuery] describing its first statement.
         *
         * Problems are reported through the returned query's error list rather than thrown:
         * - syntax errors reported by the parser;
         * - [ParserErrors.NOT_ONE_QUERY] when the input does not contain exactly one statement
         *   (only the first one is analyzed when there are several);
         * - an "unknown error" entry when a [RuntimeException] escapes parsing.
         */
        fun parse(input: String): ParsedQuery {
            val inputStream = ANTLRInputStream(input)
            val lexer = SQLiteLexer(inputStream)
            val tokenStream = CommonTokenStream(lexer)
            val parser = SQLiteParser(tokenStream)
            val syntaxErrors = arrayListOf<String>()
            parser.addErrorListener(object : BaseErrorListener() {
                // ANTLR declares offendingSymbol as a Java Object that its contract allows to
                // be null; accepting null keeps Kotlin's generated parameter check from
                // throwing inside error reporting.
                override fun syntaxError(recognizer: Recognizer<*, *>, offendingSymbol: Any?,
                                         line: Int, charPositionInLine: Int, msg: String,
                                         e: RecognitionException?) {
                    syntaxErrors.add(msg)
                }
            })
            try {
                val parsed = parser.parse()
                val statementList = parsed.sql_stmt_list()
                // `children` is null for a rule context that matched nothing, and the list may
                // hold no statement at all; both cases are "not one query", not a crash.
                val statements = statementList.firstOrNull()?.children.orEmpty()
                        .filter { it is SQLiteParser.Sql_stmtContext }
                if (statements.size != 1) {
                    syntaxErrors.add(ParserErrors.NOT_ONE_QUERY)
                }
                // Return every collected error (parser errors included), not just NOT_ONE_QUERY.
                val statement = statements.firstOrNull()
                        ?: return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                                syntaxErrors)
                return QueryVisitor(input, syntaxErrors, statement).createParsedQuery()
            } catch (antlrError: RuntimeException) {
                return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                        listOf("unknown error while parsing $input : ${antlrError.message}"))
            }
        }
    }
}
158 
/** Kind of SQL statement recognized by the parser. */
enum class QueryType {
    /** No recognized statement kind was found. */
    UNKNOWN,
    SELECT,
    DELETE,
    UPDATE,
    EXPLAIN,
    INSERT;

    companion object {
        // IF you change this, don't forget to update @Query documentation.
        /** Statement kinds that Room supports (mirrored in the @Query documentation). */
        val SUPPORTED: HashSet<QueryType> = listOf(SELECT, DELETE, UPDATE).toHashSet()
    }
}
172 
/**
 * SQLite column type affinities, together with the Java types associated with each one
 * (see [getTypeMirrors]).
 */
enum class SQLTypeAffinity {
    NULL,
    TEXT,
    INTEGER,
    REAL,
    BLOB;

    /**
     * Returns the Java types associated with this affinity:
     * - [TEXT]: `java.lang.String`;
     * - [INTEGER]: int, byte, char, boolean, long and short, each with its boxed class;
     * - [REAL]: double and float, each with its boxed class;
     * - [BLOB]: `byte[]`;
     * - [NULL]: no types.
     *
     * Never returns null; the nullable return type is kept for caller compatibility.
     */
    fun getTypeMirrors(env : ProcessingEnvironment) : List<TypeMirror>? {
        val typeUtils = env.typeUtils
        // Every constant is listed explicitly (no `else`) so adding a new affinity forces
        // this mapping to be updated.
        return when (this) {
            TEXT -> listOf(env.elementUtils.getTypeElement("java.lang.String").asType())
            INTEGER -> withBoxedTypes(env, TypeKind.INT, TypeKind.BYTE, TypeKind.CHAR,
                    TypeKind.BOOLEAN, TypeKind.LONG, TypeKind.SHORT)
            REAL -> withBoxedTypes(env, TypeKind.DOUBLE, TypeKind.FLOAT)
            BLOB -> listOf(typeUtils.getArrayType(
                    typeUtils.getPrimitiveType(TypeKind.BYTE)))
            NULL -> emptyList()
        }
    }

    /** Returns, for each of [primitives], the primitive type followed by its boxed class type. */
    private fun withBoxedTypes(env : ProcessingEnvironment, vararg primitives : TypeKind) :
            List<TypeMirror> {
        return primitives.flatMap {
            val primitiveType = env.typeUtils.getPrimitiveType(it)
            listOf(primitiveType, env.typeUtils.boxedClass(primitiveType).asType())
        }
    }

    companion object {
        // converts from ColumnInfo#SQLiteTypeAffinity
        /** Maps a ColumnInfo affinity constant to its [SQLTypeAffinity], or null otherwise. */
        fun fromAnnotationValue(value : Int) : SQLTypeAffinity? {
            return when(value) {
                ColumnInfo.BLOB -> BLOB
                ColumnInfo.INTEGER -> INTEGER
                ColumnInfo.REAL -> REAL
                ColumnInfo.TEXT -> TEXT
                else -> null
            }
        }
    }
}
213 
/** Collating sequences that can be selected through a ColumnInfo collation constant. */
enum class Collate {
    BINARY,
    NOCASE,
    RTRIM;

    companion object {
        /**
         * Converts a ColumnInfo collation constant into the matching [Collate] entry.
         * Any value other than BINARY, NOCASE or RTRIM yields null.
         */
        fun fromAnnotationValue(value: Int): Collate? = when (value) {
            ColumnInfo.BINARY -> BINARY
            ColumnInfo.NOCASE -> NOCASE
            ColumnInfo.RTRIM -> RTRIM
            else -> null
        }
    }
}
230