1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/common/sqlite_utils.h"
6
7 #include <list>
8
9 #include "base/file_path.h"
10 #include "base/lazy_instance.h"
11 #include "base/logging.h"
12 #include "base/stl_util-inl.h"
13 #include "base/string16.h"
14 #include "base/synchronization/lock.h"
15
16 // The vanilla error handler implements the common fucntionality for all the
17 // error handlers. Specialized error handlers are expected to only override
18 // the Handler() function.
19 class VanillaSQLErrorHandler : public SQLErrorHandler {
20 public:
VanillaSQLErrorHandler()21 VanillaSQLErrorHandler() : error_(SQLITE_OK) {
22 }
GetLastError() const23 virtual int GetLastError() const {
24 return error_;
25 }
26 protected:
27 int error_;
28 };
29
30 class DebugSQLErrorHandler: public VanillaSQLErrorHandler {
31 public:
HandleError(int error,sqlite3 * db)32 virtual int HandleError(int error, sqlite3* db) {
33 error_ = error;
34 NOTREACHED() << "sqlite error " << error
35 << " db " << static_cast<void*>(db);
36 return error;
37 }
38 };
39
40 class ReleaseSQLErrorHandler : public VanillaSQLErrorHandler {
41 public:
HandleError(int error,sqlite3 * db)42 virtual int HandleError(int error, sqlite3* db) {
43 error_ = error;
44 // Used to have a CHECK here. Got lots of crashes.
45 return error;
46 }
47 };
48
49 // The default error handler factory is also in charge of managing the
50 // lifetime of the error objects. This object is multi-thread safe.
51 class DefaultSQLErrorHandlerFactory : public SQLErrorHandlerFactory {
52 public:
~DefaultSQLErrorHandlerFactory()53 ~DefaultSQLErrorHandlerFactory() {
54 STLDeleteContainerPointers(errors_.begin(), errors_.end());
55 }
56
Make()57 virtual SQLErrorHandler* Make() {
58 SQLErrorHandler* handler;
59 #ifndef NDEBUG
60 handler = new DebugSQLErrorHandler;
61 #else
62 handler = new ReleaseSQLErrorHandler;
63 #endif // NDEBUG
64 AddHandler(handler);
65 return handler;
66 }
67
68 private:
AddHandler(SQLErrorHandler * handler)69 void AddHandler(SQLErrorHandler* handler) {
70 base::AutoLock lock(lock_);
71 errors_.push_back(handler);
72 }
73
74 typedef std::list<SQLErrorHandler*> ErrorList;
75 ErrorList errors_;
76 base::Lock lock_;
77 };
78
79 static base::LazyInstance<DefaultSQLErrorHandlerFactory>
80 g_default_sql_error_handler_factory(base::LINKER_INITIALIZED);
81
GetErrorHandlerFactory()82 SQLErrorHandlerFactory* GetErrorHandlerFactory() {
83 // TODO(cpu): Testing needs to override the error handler.
84 // Destruction of DefaultSQLErrorHandlerFactory handled by at_exit manager.
85 return g_default_sql_error_handler_factory.Pointer();
86 }
87
88 namespace sqlite_utils {
89
OpenSqliteDb(const FilePath & filepath,sqlite3 ** database)90 int OpenSqliteDb(const FilePath& filepath, sqlite3** database) {
91 #if defined(OS_WIN)
92 // We want the default encoding to always be UTF-8, so we use the
93 // 8-bit version of open().
94 return sqlite3_open(WideToUTF8(filepath.value()).c_str(), database);
95 #elif defined(OS_POSIX)
96 return sqlite3_open(filepath.value().c_str(), database);
97 #endif
98 }
99
DoesSqliteTableExist(sqlite3 * db,const char * db_name,const char * table_name)100 bool DoesSqliteTableExist(sqlite3* db,
101 const char* db_name,
102 const char* table_name) {
103 // sqlite doesn't allow binding parameters as table names, so we have to
104 // manually construct the sql
105 std::string sql("SELECT name FROM ");
106 if (db_name && db_name[0]) {
107 sql.append(db_name);
108 sql.push_back('.');
109 }
110 sql.append("sqlite_master WHERE type='table' AND name=?");
111
112 SQLStatement statement;
113 if (statement.prepare(db, sql.c_str()) != SQLITE_OK)
114 return false;
115
116 if (statement.bind_text(0, table_name) != SQLITE_OK)
117 return false;
118
119 // we only care about if this matched a row, not the actual data
120 return sqlite3_step(statement.get()) == SQLITE_ROW;
121 }
122
DoesSqliteColumnExist(sqlite3 * db,const char * database_name,const char * table_name,const char * column_name,const char * column_type)123 bool DoesSqliteColumnExist(sqlite3* db,
124 const char* database_name,
125 const char* table_name,
126 const char* column_name,
127 const char* column_type) {
128 SQLStatement s;
129 std::string sql;
130 sql.append("PRAGMA ");
131 if (database_name && database_name[0]) {
132 // optional database name specified
133 sql.append(database_name);
134 sql.push_back('.');
135 }
136 sql.append("TABLE_INFO(");
137 sql.append(table_name);
138 sql.append(")");
139
140 if (s.prepare(db, sql.c_str()) != SQLITE_OK)
141 return false;
142
143 while (s.step() == SQLITE_ROW) {
144 if (!s.column_string(1).compare(column_name)) {
145 if (column_type && column_type[0])
146 return !s.column_string(2).compare(column_type);
147 return true;
148 }
149 }
150 return false;
151 }
152
DoesSqliteTableHaveRow(sqlite3 * db,const char * table_name)153 bool DoesSqliteTableHaveRow(sqlite3* db, const char* table_name) {
154 SQLStatement s;
155 std::string b;
156 b.append("SELECT * FROM ");
157 b.append(table_name);
158
159 if (s.prepare(db, b.c_str()) != SQLITE_OK)
160 return false;
161
162 return s.step() == SQLITE_ROW;
163 }
164
165 } // namespace sqlite_utils
166
SQLTransaction(sqlite3 * db)167 SQLTransaction::SQLTransaction(sqlite3* db) : db_(db), began_(false) {
168 }
169
~SQLTransaction()170 SQLTransaction::~SQLTransaction() {
171 if (began_) {
172 Rollback();
173 }
174 }
175
BeginCommand(const char * command)176 int SQLTransaction::BeginCommand(const char* command) {
177 int rv = SQLITE_ERROR;
178 if (!began_ && db_) {
179 rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
180 began_ = (rv == SQLITE_OK);
181 }
182 return rv;
183 }
184
EndCommand(const char * command)185 int SQLTransaction::EndCommand(const char* command) {
186 int rv = SQLITE_ERROR;
187 if (began_ && db_) {
188 rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
189 began_ = (rv != SQLITE_OK);
190 }
191 return rv;
192 }
193
~SQLNestedTransactionSite()194 SQLNestedTransactionSite::~SQLNestedTransactionSite() {
195 DCHECK(!top_transaction_);
196 }
197
SetTopTransaction(SQLNestedTransaction * top)198 void SQLNestedTransactionSite::SetTopTransaction(SQLNestedTransaction* top) {
199 DCHECK(!top || !top_transaction_);
200 top_transaction_ = top;
201 }
202
SQLNestedTransaction(SQLNestedTransactionSite * site)203 SQLNestedTransaction::SQLNestedTransaction(SQLNestedTransactionSite* site)
204 : SQLTransaction(site->GetSqlite3DB()),
205 needs_rollback_(false),
206 site_(site) {
207 DCHECK(site);
208 if (site->GetTopTransaction() == NULL) {
209 site->SetTopTransaction(this);
210 }
211 }
212
~SQLNestedTransaction()213 SQLNestedTransaction::~SQLNestedTransaction() {
214 if (began_) {
215 Rollback();
216 }
217 if (site_->GetTopTransaction() == this) {
218 site_->SetTopTransaction(NULL);
219 }
220 }
221
BeginCommand(const char * command)222 int SQLNestedTransaction::BeginCommand(const char* command) {
223 DCHECK(db_);
224 DCHECK(site_ && site_->GetTopTransaction());
225 if (!db_ || began_) {
226 return SQLITE_ERROR;
227 }
228 if (site_->GetTopTransaction() == this) {
229 int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
230 began_ = (rv == SQLITE_OK);
231 if (began_) {
232 site_->OnBegin();
233 }
234 return rv;
235 } else {
236 if (site_->GetTopTransaction()->needs_rollback_) {
237 return SQLITE_ERROR;
238 }
239 began_ = true;
240 return SQLITE_OK;
241 }
242 }
243
EndCommand(const char * command)244 int SQLNestedTransaction::EndCommand(const char* command) {
245 DCHECK(db_);
246 DCHECK(site_ && site_->GetTopTransaction());
247 if (!db_ || !began_) {
248 return SQLITE_ERROR;
249 }
250 if (site_->GetTopTransaction() == this) {
251 if (needs_rollback_) {
252 sqlite3_exec(db_, "ROLLBACK", NULL, NULL, NULL);
253 began_ = false; // reset so we don't try to rollback or call
254 // OnRollback() again
255 site_->OnRollback();
256 return SQLITE_ERROR;
257 } else {
258 int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
259 began_ = (rv != SQLITE_OK);
260 if (strcmp(command, "ROLLBACK") == 0) {
261 began_ = false; // reset so we don't try to rollbck or call
262 // OnRollback() again
263 site_->OnRollback();
264 } else {
265 DCHECK(strcmp(command, "COMMIT") == 0);
266 if (rv == SQLITE_OK) {
267 site_->OnCommit();
268 }
269 }
270 return rv;
271 }
272 } else {
273 if (strcmp(command, "ROLLBACK") == 0) {
274 site_->GetTopTransaction()->needs_rollback_ = true;
275 }
276 began_ = false;
277 return SQLITE_OK;
278 }
279 }
280
prepare(sqlite3 * db,const char * sql,int sql_len)281 int SQLStatement::prepare(sqlite3* db, const char* sql, int sql_len) {
282 DCHECK(!stmt_);
283 int rv = sqlite3_prepare_v2(db, sql, sql_len, &stmt_, NULL);
284 if (rv != SQLITE_OK) {
285 SQLErrorHandler* error_handler = GetErrorHandlerFactory()->Make();
286 return error_handler->HandleError(rv, db);
287 }
288 return rv;
289 }
290
step()291 int SQLStatement::step() {
292 DCHECK(stmt_);
293 int status = sqlite3_step(stmt_);
294 if ((status == SQLITE_ROW) || (status == SQLITE_DONE))
295 return status;
296 // We got a problem.
297 SQLErrorHandler* error_handler = GetErrorHandlerFactory()->Make();
298 return error_handler->HandleError(status, db_handle());
299 }
300
reset()301 int SQLStatement::reset() {
302 DCHECK(stmt_);
303 return sqlite3_reset(stmt_);
304 }
305
last_insert_rowid()306 sqlite_int64 SQLStatement::last_insert_rowid() {
307 DCHECK(stmt_);
308 return sqlite3_last_insert_rowid(db_handle());
309 }
310
changes()311 int SQLStatement::changes() {
312 DCHECK(stmt_);
313 return sqlite3_changes(db_handle());
314 }
315
db_handle()316 sqlite3* SQLStatement::db_handle() {
317 DCHECK(stmt_);
318 return sqlite3_db_handle(stmt_);
319 }
320
bind_parameter_count()321 int SQLStatement::bind_parameter_count() {
322 DCHECK(stmt_);
323 return sqlite3_bind_parameter_count(stmt_);
324 }
325
bind_blob(int index,std::vector<unsigned char> * blob)326 int SQLStatement::bind_blob(int index, std::vector<unsigned char>* blob) {
327 if (blob) {
328 const void* value = blob->empty() ? NULL : &(*blob)[0];
329 int len = static_cast<int>(blob->size());
330 return bind_blob(index, value, len);
331 } else {
332 return bind_null(index);
333 }
334 }
335
bind_blob(int index,const void * value,int value_len)336 int SQLStatement::bind_blob(int index, const void* value, int value_len) {
337 return bind_blob(index, value, value_len, SQLITE_TRANSIENT);
338 }
339
bind_blob(int index,const void * value,int value_len,Function dtor)340 int SQLStatement::bind_blob(int index, const void* value, int value_len,
341 Function dtor) {
342 DCHECK(stmt_);
343 return sqlite3_bind_blob(stmt_, index + 1, value, value_len, dtor);
344 }
345
bind_double(int index,double value)346 int SQLStatement::bind_double(int index, double value) {
347 DCHECK(stmt_);
348 return sqlite3_bind_double(stmt_, index + 1, value);
349 }
350
bind_bool(int index,bool value)351 int SQLStatement::bind_bool(int index, bool value) {
352 DCHECK(stmt_);
353 return sqlite3_bind_int(stmt_, index + 1, value);
354 }
355
bind_int(int index,int value)356 int SQLStatement::bind_int(int index, int value) {
357 DCHECK(stmt_);
358 return sqlite3_bind_int(stmt_, index + 1, value);
359 }
360
bind_int64(int index,sqlite_int64 value)361 int SQLStatement::bind_int64(int index, sqlite_int64 value) {
362 DCHECK(stmt_);
363 return sqlite3_bind_int64(stmt_, index + 1, value);
364 }
365
bind_null(int index)366 int SQLStatement::bind_null(int index) {
367 DCHECK(stmt_);
368 return sqlite3_bind_null(stmt_, index + 1);
369 }
370
bind_text(int index,const char * value,int value_len,Function dtor)371 int SQLStatement::bind_text(int index, const char* value, int value_len,
372 Function dtor) {
373 DCHECK(stmt_);
374 return sqlite3_bind_text(stmt_, index + 1, value, value_len, dtor);
375 }
376
bind_text16(int index,const char16 * value,int value_len,Function dtor)377 int SQLStatement::bind_text16(int index, const char16* value, int value_len,
378 Function dtor) {
379 DCHECK(stmt_);
380 value_len *= sizeof(char16);
381 return sqlite3_bind_text16(stmt_, index + 1, value, value_len, dtor);
382 }
383
bind_value(int index,const sqlite3_value * value)384 int SQLStatement::bind_value(int index, const sqlite3_value* value) {
385 DCHECK(stmt_);
386 return sqlite3_bind_value(stmt_, index + 1, value);
387 }
388
column_count()389 int SQLStatement::column_count() {
390 DCHECK(stmt_);
391 return sqlite3_column_count(stmt_);
392 }
393
column_type(int index)394 int SQLStatement::column_type(int index) {
395 DCHECK(stmt_);
396 return sqlite3_column_type(stmt_, index);
397 }
398
column_blob(int index)399 const void* SQLStatement::column_blob(int index) {
400 DCHECK(stmt_);
401 return sqlite3_column_blob(stmt_, index);
402 }
403
column_blob_as_vector(int index,std::vector<unsigned char> * blob)404 bool SQLStatement::column_blob_as_vector(int index,
405 std::vector<unsigned char>* blob) {
406 DCHECK(stmt_);
407 const void* p = column_blob(index);
408 size_t len = column_bytes(index);
409 blob->resize(len);
410 if (blob->size() != len) {
411 return false;
412 }
413 if (len > 0)
414 memcpy(&(blob->front()), p, len);
415 return true;
416 }
417
column_blob_as_string(int index,std::string * blob)418 bool SQLStatement::column_blob_as_string(int index, std::string* blob) {
419 DCHECK(stmt_);
420 const void* p = column_blob(index);
421 size_t len = column_bytes(index);
422 blob->resize(len);
423 if (blob->size() != len) {
424 return false;
425 }
426 blob->assign(reinterpret_cast<const char*>(p), len);
427 return true;
428 }
429
column_bytes(int index)430 int SQLStatement::column_bytes(int index) {
431 DCHECK(stmt_);
432 return sqlite3_column_bytes(stmt_, index);
433 }
434
column_bytes16(int index)435 int SQLStatement::column_bytes16(int index) {
436 DCHECK(stmt_);
437 return sqlite3_column_bytes16(stmt_, index);
438 }
439
column_double(int index)440 double SQLStatement::column_double(int index) {
441 DCHECK(stmt_);
442 return sqlite3_column_double(stmt_, index);
443 }
444
column_bool(int index)445 bool SQLStatement::column_bool(int index) {
446 DCHECK(stmt_);
447 return sqlite3_column_int(stmt_, index) ? true : false;
448 }
449
column_int(int index)450 int SQLStatement::column_int(int index) {
451 DCHECK(stmt_);
452 return sqlite3_column_int(stmt_, index);
453 }
454
column_int64(int index)455 sqlite_int64 SQLStatement::column_int64(int index) {
456 DCHECK(stmt_);
457 return sqlite3_column_int64(stmt_, index);
458 }
459
column_text(int index)460 const char* SQLStatement::column_text(int index) {
461 DCHECK(stmt_);
462 return reinterpret_cast<const char*>(sqlite3_column_text(stmt_, index));
463 }
464
column_string(int index,std::string * str)465 bool SQLStatement::column_string(int index, std::string* str) {
466 DCHECK(stmt_);
467 DCHECK(str);
468 const char* s = column_text(index);
469 str->assign(s ? s : std::string());
470 return s != NULL;
471 }
472
column_string(int index)473 std::string SQLStatement::column_string(int index) {
474 std::string str;
475 column_string(index, &str);
476 return str;
477 }
478
column_text16(int index)479 const char16* SQLStatement::column_text16(int index) {
480 DCHECK(stmt_);
481 return static_cast<const char16*>(sqlite3_column_text16(stmt_, index));
482 }
483
column_string16(int index,string16 * str)484 bool SQLStatement::column_string16(int index, string16* str) {
485 DCHECK(stmt_);
486 DCHECK(str);
487 const char* s = column_text(index);
488 str->assign(s ? UTF8ToUTF16(s) : string16());
489 return (s != NULL);
490 }
491
column_string16(int index)492 string16 SQLStatement::column_string16(int index) {
493 string16 str;
494 column_string16(index, &str);
495 return str;
496 }
497
column_wstring(int index,std::wstring * str)498 bool SQLStatement::column_wstring(int index, std::wstring* str) {
499 DCHECK(stmt_);
500 DCHECK(str);
501 const char* s = column_text(index);
502 str->assign(s ? UTF8ToWide(s) : std::wstring());
503 return (s != NULL);
504 }
505
column_wstring(int index)506 std::wstring SQLStatement::column_wstring(int index) {
507 std::wstring wstr;
508 column_wstring(index, &wstr);
509 return wstr;
510 }
511