/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.arch.persistence.room.parser

import android.arch.persistence.room.ColumnInfo
import org.antlr.v4.runtime.ANTLRInputStream
import org.antlr.v4.runtime.BaseErrorListener
import org.antlr.v4.runtime.CommonTokenStream
import org.antlr.v4.runtime.RecognitionException
import org.antlr.v4.runtime.Recognizer
import org.antlr.v4.runtime.tree.ParseTree
import org.antlr.v4.runtime.tree.TerminalNode
import java.util.ArrayList
import javax.annotation.processing.ProcessingEnvironment
import javax.lang.model.type.TypeKind
import javax.lang.model.type.TypeMirror

/**
 * Parse-tree visitor that collects the pieces needed to build a [ParsedQuery] from a single
 * SQL statement: the query type, the bind parameters and the tables that are referenced.
 *
 * The visitor walks [statement] eagerly from its initializer, so all collected properties are
 * populated as soon as the constructor returns.
 *
 * @param original the SQL text the statement was parsed from; passed through to [ParsedQuery].
 * @param syntaxErrors errors reported while parsing; passed through to [ParsedQuery].
 * @param statement the parse tree of the statement to inspect.
 */
class QueryVisitor(
        val original: String,
        val syntaxErrors: ArrayList<String>,
        statement: ParseTree
) : SQLiteBaseVisitor<Void?>() {
    /** Bind parameter nodes in the order they were visited. */
    val bindingExpressions = arrayListOf<TerminalNode>()
    // table name alias mappings
    val tableNames = mutableSetOf<Table>()
    /** Unescaped names declared by common table expressions (WITH clauses). */
    val withClauseNames = mutableSetOf<String>()
    /** Type of the first direct child of the statement that maps to a known [QueryType]. */
    val queryType: QueryType

    init {
        // Only direct children are inspected; the first one with a recognizable type wins.
        queryType = (0 until statement.childCount)
                .asSequence()
                .map { findQueryType(statement.getChild(it)) }
                .firstOrNull { it != QueryType.UNKNOWN }
                ?: QueryType.UNKNOWN

        statement.accept(this)
    }

    /**
     * Maps a parse tree node to its [QueryType] based on the node's class, or on its text for
     * the `EXPLAIN` terminal. Returns [QueryType.UNKNOWN] for anything else.
     */
    private fun findQueryType(statement: ParseTree): QueryType {
        return when (statement) {
            is SQLiteParser.Factored_select_stmtContext,
            is SQLiteParser.Compound_select_stmtContext,
            is SQLiteParser.Select_stmtContext,
            is SQLiteParser.Simple_select_stmtContext ->
                QueryType.SELECT

            is SQLiteParser.Delete_stmt_limitedContext,
            is SQLiteParser.Delete_stmtContext ->
                QueryType.DELETE

            is SQLiteParser.Insert_stmtContext ->
                QueryType.INSERT
            is SQLiteParser.Update_stmtContext,
            is SQLiteParser.Update_stmt_limitedContext ->
                QueryType.UPDATE
            is TerminalNode -> when (statement.text) {
                "EXPLAIN" -> QueryType.EXPLAIN
                else -> QueryType.UNKNOWN
            }
            else -> QueryType.UNKNOWN
        }
    }

    /** Records the expression's bind parameter, if it has one, then visits its children. */
    override fun visitExpr(ctx: SQLiteParser.ExprContext): Void? {
        ctx.BIND_PARAMETER()?.let { bindingExpressions.add(it) }
        return super.visitExpr(ctx)
    }

    /**
     * Builds the [ParsedQuery] from everything collected during the visit. Bind parameters are
     * ordered by the start of their source interval.
     */
    fun createParsedQuery(): ParsedQuery {
        return ParsedQuery(original,
                queryType,
                bindingExpressions.sortedBy { it.sourceInterval.a },
                tableNames,
                syntaxErrors)
    }

    /** Remembers the (unescaped) name of a common table expression, then visits its children. */
    override fun visitCommon_table_expression(
            ctx: SQLiteParser.Common_table_expressionContext): Void? {
        val tableName = ctx.table_name()?.text
        if (tableName != null) {
            withClauseNames.add(unescapeIdentifier(tableName))
        }
        return super.visitCommon_table_expression(ctx)
    }

    /**
     * Records a referenced table together with its alias (the table name itself when there is
     * no alias), unless the name refers to a common table expression seen earlier.
     */
    override fun visitTable_or_subquery(ctx: SQLiteParser.Table_or_subqueryContext): Void? {
        val tableName = ctx.table_name()?.text
        if (tableName != null) {
            // withClauseNames stores unescaped names, so the lookup must use the unescaped
            // form too; otherwise a quoted reference to a WITH table is treated as a real table.
            val unescapedName = unescapeIdentifier(tableName)
            if (unescapedName !in withClauseNames) {
                val unescapedAlias = ctx.table_alias()?.text
                        ?.let { unescapeIdentifier(it) }
                        ?: unescapedName
                tableNames.add(Table(unescapedName, unescapedAlias))
            }
        }
        return super.visitTable_or_subquery(ctx)
    }

    /**
     * Trims [text] and strips surrounding backticks or double quotes, repeating until no
     * matching pair of quotes is left.
     */
    private fun unescapeIdentifier(text: String): String {
        val trimmed = text.trim()
        // A lone quote character both starts and ends with the quote; require two characters
        // so that substring() is never called with start > end.
        if (trimmed.length >= 2) {
            val first = trimmed.first()
            if ((first == '`' || first == '"') && trimmed.last() == first) {
                return unescapeIdentifier(trimmed.substring(1, trimmed.length - 1))
            }
        }
        return trimmed
    }
}

/** Entry point for turning a SQL string into a [ParsedQuery]. */
class SqlParser {
    companion object {
        /**
         * Parses [input] and returns the resulting [ParsedQuery].
         *
         * This function does not throw for malformed SQL: syntax errors reported by the parser
         * are collected into the returned query, [ParserErrors.NOT_ONE_QUERY] is reported when
         * the input does not contain exactly one statement, and any [RuntimeException] raised
         * while parsing is converted into an error entry on a [QueryType.UNKNOWN] query.
         */
        fun parse(input: String): ParsedQuery {
            val inputStream = ANTLRInputStream(input)
            val lexer = SQLiteLexer(inputStream)
            val tokenStream = CommonTokenStream(lexer)
            val parser = SQLiteParser(tokenStream)
            val syntaxErrors = arrayListOf<String>()
            parser.addErrorListener(object : BaseErrorListener() {
                // offendingSymbol is declared nullable because ANTLR documents that it can be
                // null; a non-null parameter would make Kotlin throw before the body runs.
                override fun syntaxError(recognizer: Recognizer<*, *>, offendingSymbol: Any?,
                                         line: Int, charPositionInLine: Int, msg: String,
                                         e: RecognitionException?) {
                    syntaxErrors.add(msg)
                }
            })
            try {
                val parsed = parser.parse()
                val statementList = parsed.sql_stmt_list()
                if (statementList.isEmpty()) {
                    syntaxErrors.add(ParserErrors.NOT_ONE_QUERY)
                    return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                            listOf(ParserErrors.NOT_ONE_QUERY))
                }
                // children may be null when the context has no child nodes.
                val statements = statementList.first().children.orEmpty()
                        .filterIsInstance<SQLiteParser.Sql_stmtContext>()
                if (statements.size != 1) {
                    syntaxErrors.add(ParserErrors.NOT_ONE_QUERY)
                }
                // With no statement at all there is nothing to visit; report the collected
                // errors instead of failing with an unrelated "list is empty" exception.
                val statement = statements.firstOrNull()
                        ?: return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                                syntaxErrors)
                return QueryVisitor(input, syntaxErrors, statement).createParsedQuery()
            } catch (antlrError: RuntimeException) {
                return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                        listOf("unknown error while parsing $input : ${antlrError.message}"))
            }
        }
    }
}

/** Kind of SQL statement detected by [QueryVisitor]. */
enum class QueryType {
    UNKNOWN,
    SELECT,
    DELETE,
    UPDATE,
    EXPLAIN,
    INSERT;

    companion object {
        // IF you change this, don't forget to update @Query documentation.
        val SUPPORTED = hashSetOf(SELECT, DELETE, UPDATE)
    }
}

/** SQLite type affinities, with helpers to map them to Java types and annotation values. */
enum class SQLTypeAffinity {
    NULL,
    TEXT,
    INTEGER,
    REAL,
    BLOB;

    /**
     * Returns the Java types associated with this affinity, resolved through [env]:
     * `java.lang.String` for [TEXT], the listed primitives and their boxed classes for
     * [INTEGER] and [REAL], `byte[]` for [BLOB], and an empty list for [NULL].
     */
    fun getTypeMirrors(env: ProcessingEnvironment): List<TypeMirror>? {
        val typeUtils = env.typeUtils
        return when (this) {
            TEXT -> listOf(env.elementUtils.getTypeElement("java.lang.String").asType())
            INTEGER -> withBoxedTypes(env, TypeKind.INT, TypeKind.BYTE, TypeKind.CHAR,
                    TypeKind.BOOLEAN, TypeKind.LONG, TypeKind.SHORT)
            REAL -> withBoxedTypes(env, TypeKind.DOUBLE, TypeKind.FLOAT)
            BLOB -> listOf(typeUtils.getArrayType(
                    typeUtils.getPrimitiveType(TypeKind.BYTE)))
            // Listed explicitly (instead of `else`) so the compiler flags any new affinity.
            NULL -> emptyList()
        }
    }

    /** Returns each primitive type immediately followed by its boxed class type. */
    private fun withBoxedTypes(env: ProcessingEnvironment, vararg primitives: TypeKind):
            List<TypeMirror> {
        return primitives.flatMap {
            val primitiveType = env.typeUtils.getPrimitiveType(it)
            listOf(primitiveType, env.typeUtils.boxedClass(primitiveType).asType())
        }
    }

    companion object {
        // converts from ColumnInfo#SQLiteTypeAffinity
        /** Returns the affinity for a [ColumnInfo] constant, or null if it is not one of them. */
        fun fromAnnotationValue(value: Int): SQLTypeAffinity? {
            return when (value) {
                ColumnInfo.BLOB -> BLOB
                ColumnInfo.INTEGER -> INTEGER
                ColumnInfo.REAL -> REAL
                ColumnInfo.TEXT -> TEXT
                else -> null
            }
        }
    }
}

/** Collation sequences that can be selected through [ColumnInfo]. */
enum class Collate {
    BINARY,
    NOCASE,
    RTRIM;

    companion object {
        /** Returns the collation for a [ColumnInfo] constant, or null if it is not one of them. */
        fun fromAnnotationValue(value: Int): Collate? {
            return when (value) {
                ColumnInfo.BINARY -> BINARY
                ColumnInfo.NOCASE -> NOCASE
                ColumnInfo.RTRIM -> RTRIM
                else -> null
            }
        }
    }
}