1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "sg_classify_client.h"
17
18 #ifndef SECURITY_GUARD_TRIM_MODEL_ANALYSIS
19 #include <future>
20 #include "iremote_broker.h"
21 #include "iservice_registry.h"
22 #include "securec.h"
23 #include "risk_analysis_manager_callback_service.h"
24 #include "risk_analysis_manager_proxy.h"
25 #include "risk_analysis_manager.h"
26 #endif
27
28 #include "security_guard_define.h"
29 #include "security_guard_log.h"
30
31 #ifndef SECURITY_GUARD_TRIM_MODEL_ANALYSIS
32 namespace {
33 constexpr int32_t TIMEOUT_REPLY = 15000;
34 }
35
36 using namespace OHOS;
37 using namespace OHOS::Security::SecurityGuard;
38
39 static std::mutex g_mutex;
40
RequestSecurityModelResult(const std::string & devId,uint32_t modelId,const std::string & param,ResultCallback callback)41 static int32_t RequestSecurityModelResult(const std::string &devId, uint32_t modelId,
42 const std::string ¶m, ResultCallback callback)
43 {
44 auto registry = SystemAbilityManagerClient::GetInstance().GetSystemAbilityManager();
45 if (registry == nullptr) {
46 SGLOGE("GetSystemAbilityManager error");
47 return NULL_OBJECT;
48 }
49
50 auto object = registry->GetSystemAbility(RISK_ANALYSIS_MANAGER_SA_ID);
51 auto proxy = iface_cast<RiskAnalysisManager>(object);
52 if (proxy == nullptr) {
53 SGLOGE("proxy is null");
54 return NULL_OBJECT;
55 }
56
57 sptr<RiskAnalysisManagerCallbackService> stub = new (std::nothrow) RiskAnalysisManagerCallbackService(callback);
58 if (stub == nullptr) {
59 SGLOGE("stub is null");
60 return NULL_OBJECT;
61 }
62 int32_t ret = proxy->RequestSecurityModelResult(devId, modelId, param, stub);
63 SGLOGI("RequestSecurityModelResult result, ret=%{public}d", ret);
64 return ret;
65 }
66 #endif
67 namespace OHOS::Security::SecurityGuard {
RequestSecurityModelResultSync(const std::string & devId,uint32_t modelId,const std::string & param,SecurityModelResult & result)68 int32_t RequestSecurityModelResultSync(const std::string &devId, uint32_t modelId,
69 const std::string ¶m, SecurityModelResult &result)
70 {
71 #ifndef SECURITY_GUARD_TRIM_MODEL_ANALYSIS
72 if (devId.length() >= DEVICE_ID_MAX_LEN) {
73 return BAD_PARAM;
74 }
75 std::unique_lock<std::mutex> lock(g_mutex);
76 auto promise = std::make_shared<std::promise<SecurityModelResult>>();
77 auto future = promise->get_future();
78 auto func = [promise, param] (const std::string &devId, uint32_t modelId,
79 const std::string &result) mutable -> int32_t {
80 SecurityModelResult modelResult = {
81 .devId = devId,
82 .modelId = modelId,
83 .param = param,
84 .result = result
85 };
86 promise->set_value(modelResult);
87 return SUCCESS;
88 };
89
90 int32_t code = RequestSecurityModelResult(devId, modelId, param, func);
91 if (code != SUCCESS) {
92 SGLOGE("RequestSecurityModelResult error, code=%{public}d", code);
93 return code;
94 }
95 std::chrono::milliseconds span(TIMEOUT_REPLY);
96 if (future.wait_for(span) == std::future_status::timeout) {
97 SGLOGE("wait timeout");
98 return TIME_OUT;
99 }
100 result = future.get();
101 return SUCCESS;
102 #else
103 return 0;
104 #endif
105 }
106
RequestSecurityModelResultAsync(const std::string & devId,uint32_t modelId,const std::string & param,SecurityGuardRiskCallback callback)107 int32_t RequestSecurityModelResultAsync(const std::string &devId, uint32_t modelId,
108 const std::string ¶m, SecurityGuardRiskCallback callback)
109 {
110 #ifndef SECURITY_GUARD_TRIM_MODEL_ANALYSIS
111 if (devId.length() >= DEVICE_ID_MAX_LEN) {
112 return BAD_PARAM;
113 }
114 std::unique_lock<std::mutex> lock(g_mutex);
115 auto func = [callback, param] (const std::string &devId,
116 uint32_t modelId, const std::string &result) -> int32_t {
117 callback(SecurityModelResult{devId, modelId, param, result});
118 return SUCCESS;
119 };
120
121 return RequestSecurityModelResult(devId, modelId, param, func);
122 #else
123 return 0;
124 #endif
125 }
126
127 // LCOV_EXCL_START
StartSecurityModel(uint32_t modelId,const std::string & param)128 int32_t StartSecurityModel(uint32_t modelId, const std::string ¶m)
129 {
130 SGLOGI("enter StartSecurityModel");
131 #ifndef SECURITY_GUARD_TRIM_MODEL_ANALYSIS
132 auto registry = SystemAbilityManagerClient::GetInstance().GetSystemAbilityManager();
133 if (registry == nullptr) {
134 SGLOGE("GetSystemAbilityManager error");
135 return NULL_OBJECT;
136 }
137
138 auto object = registry->GetSystemAbility(RISK_ANALYSIS_MANAGER_SA_ID);
139 auto proxy = iface_cast<RiskAnalysisManager>(object);
140 if (proxy == nullptr) {
141 SGLOGE("proxy is null");
142 return NULL_OBJECT;
143 }
144
145 int32_t ret = proxy->StartSecurityModel(modelId, param);
146 SGLOGI("StartSecurityModel result, ret=%{public}d", ret);
147 return ret;
148 #else
149 return 0;
150 #endif
151 }
152 // LCOV_EXCL_STOP
153 }
154
155 #ifdef __cplusplus
156 extern "C" {
157 #endif
158 #ifndef SECURITY_GUARD_TRIM_MODEL_ANALYSIS
FillingRequestResult(const OHOS::Security::SecurityGuard::SecurityModelResult & cppResult,::SecurityModelResult * result)159 static int32_t FillingRequestResult(const OHOS::Security::SecurityGuard::SecurityModelResult &cppResult,
160 ::SecurityModelResult *result)
161 {
162 if (cppResult.devId.length() >= DEVICE_ID_MAX_LEN || cppResult.result.length() >= RESULT_MAX_LEN) {
163 return BAD_PARAM;
164 }
165
166 result->modelId = cppResult.modelId;
167 errno_t rc = memcpy_s(result->devId.identity, DEVICE_ID_MAX_LEN, cppResult.devId.c_str(), cppResult.devId.length());
168 if (rc != EOK) {
169 return NULL_OBJECT;
170 }
171 result->devId.length = cppResult.devId.length();
172
173 rc = memcpy_s(result->result, RESULT_MAX_LEN, cppResult.result.c_str(), cppResult.result.length());
174 if (rc != EOK) {
175 return NULL_OBJECT;
176 }
177 result->resultLen = cppResult.result.length();
178
179 SGLOGD("modelId=%{public}u, result=%{public}s", cppResult.modelId, cppResult.result.c_str());
180 return SUCCESS;
181 }
182
CovertDevId(const DeviceIdentify * devId)183 static std::string CovertDevId(const DeviceIdentify *devId)
184 {
185 std::vector<char> id(DEVICE_ID_MAX_LEN, '\0');
186 std::copy(&devId->identity[0], &devId->identity[DEVICE_ID_MAX_LEN - 1], id.begin());
187 return std::string{id.data()};
188 }
189 #endif
RequestSecurityModelResultSync(const DeviceIdentify * devId,uint32_t modelId,::SecurityModelResult * result)190 int32_t RequestSecurityModelResultSync(const DeviceIdentify *devId, uint32_t modelId, ::SecurityModelResult *result)
191 {
192 #ifndef SECURITY_GUARD_TRIM_MODEL_ANALYSIS
193 if (devId == nullptr || result == nullptr || devId->length >= DEVICE_ID_MAX_LEN) {
194 return BAD_PARAM;
195 }
196 OHOS::Security::SecurityGuard::SecurityModelResult tmp;
197 int32_t ret = OHOS::Security::SecurityGuard::RequestSecurityModelResultSync(CovertDevId(devId), modelId, "", tmp);
198 FillingRequestResult(tmp, result);
199 return ret;
200 #else
201 return 0;
202 #endif
203 }
204
RequestSecurityModelResultAsync(const DeviceIdentify * devId,uint32_t modelId,::SecurityGuardRiskCallback callback)205 int32_t RequestSecurityModelResultAsync(const DeviceIdentify *devId, uint32_t modelId,
206 ::SecurityGuardRiskCallback callback)
207 {
208 #ifndef SECURITY_GUARD_TRIM_MODEL_ANALYSIS
209 if (devId == nullptr || devId->length >= DEVICE_ID_MAX_LEN) {
210 return BAD_PARAM;
211 }
212 auto cppCallBack = [callback](const OHOS::Security::SecurityGuard::SecurityModelResult &tmp) {
213 ::SecurityModelResult result{};
214 FillingRequestResult(tmp, &result);
215 callback(&result);
216 };
217 return OHOS::Security::SecurityGuard::RequestSecurityModelResultAsync(CovertDevId(devId), modelId, "", cppCallBack);
218 #else
219 return 0;
220 #endif
221 }
222
223 #ifdef __cplusplus
224 }
225 #endif
226