1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #include "sql_analyzer.h"
16 #include "log.h"
17
18 namespace OHOS::Request::Download {
19 constexpr int POS_ADD_TWO = 2;
20
SqlAnalyzer()21 SqlAnalyzer::SqlAnalyzer()
22 {
23 }
24
~SqlAnalyzer()25 SqlAnalyzer::~SqlAnalyzer()
26 {
27 }
28
CheckValuesBucket(const NativeRdb::ValuesBucket & value)29 bool SqlAnalyzer::CheckValuesBucket(const NativeRdb::ValuesBucket &value)
30 {
31 std::map<std::string, NativeRdb::ValueObject> valuesMap;
32 value.GetAll(valuesMap);
33 for (auto it = valuesMap.begin(); it != valuesMap.end(); ++it) {
34 std::string key = it->first;
35 bool isKey = FindIllegalWords(key);
36 if (isKey) {
37 DOWNLOAD_HILOGE("SqlAnalyzer CheckValuesBucket key is %{public}s error", key.c_str());
38 return false;
39 }
40 NativeRdb::ValueObject value = it->second;
41 if (value.GetType() == NativeRdb::ValueObjectType::TYPE_STRING) {
42 std::string str;
43 value.GetString(str);
44 bool isValue = FindIllegalWords(str);
45 if (isValue) {
46 DOWNLOAD_HILOGE("SqlAnalyzer CheckValuesBucket value is %{public}s error", str.c_str());
47 return false;
48 }
49 }
50 }
51 return true;
52 }
53
CharCheck(char & ch,std::string sql,std::size_t & pos)54 bool SqlAnalyzer::CharCheck(char &ch, std::string sql, std::size_t &pos)
55 {
56 if (ch == '[') {
57 pos++;
58 std::size_t found = sql.find(']', pos);
59 if (found == std::string::npos) {
60 return true;
61 }
62 pos++;
63 }
64 if (ch == '-' && PickChar(sql, pos + 1) == '-') {
65 pos += POS_ADD_TWO;
66 std::size_t found = sql.find('\n', pos);
67 if (found == std::string::npos) {
68 return true;
69 }
70 pos++;
71 }
72 if (ch == '/' && PickChar(sql, pos + 1) == '*') {
73 pos += POS_ADD_TWO;
74 std::size_t found = sql.find("*/", pos);
75 if (found == std::string::npos) {
76 return true;
77 }
78 pos += POS_ADD_TWO;
79 }
80 if (ch == ';') {
81 return true;
82 }
83 pos++;
84 return false;
85 }
86
StrCheck(char & ch,std::size_t strlen,std::string sql,std::size_t & pos)87 bool SqlAnalyzer::StrCheck(char &ch, std::size_t strlen, std::string sql, std::size_t &pos)
88 {
89 if (IsInStr(ch, "'\"`") == 0) {
90 pos++;
91 while (pos < strlen) {
92 std::size_t found = sql.find(ch, pos);
93 if (found == std::string::npos) {
94 return true;
95 }
96 if (PickChar(sql, pos + 1) != ch) {
97 break;
98 }
99 pos += POS_ADD_TWO;
100 }
101 }
102 return false;
103 }
104
FindIllegalWords(std::string sql)105 bool SqlAnalyzer::FindIllegalWords(std::string sql)
106 {
107 if (sql.empty()) {
108 return false;
109 }
110 std::size_t pos = 0;
111 std::size_t strlen = sql.length();
112 while (pos < strlen) {
113 char ch = PickChar(sql, pos);
114 if (IsLetter(ch)) {
115 std::size_t start = pos;
116 pos++;
117 while (IsLetterNumber(PickChar(sql, pos))) {
118 pos++;
119 }
120 std::size_t count = pos - start + 1;
121 sql.substr(start, count);
122 }
123 if (IsInStr(ch, "'\"`") == 0) {
124 if (StrCheck(ch, strlen, sql, pos)) {
125 return true;
126 } else {
127 continue;
128 }
129 }
130 if (CharCheck(ch, sql, pos)) {
131 return true;
132 } else {
133 continue;
134 }
135 }
136 return false;
137 }
138 } // namespace OHOS::Request::Download