• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #ifdef HDC_SUPPORT_ENCRYPT_TCP
16 #include "hdc_ssl.h"
17 
18 namespace Hdc {
HdcSSLBase(SSLInfoPtr hSSLInfo)19 HdcSSLBase::HdcSSLBase(SSLInfoPtr hSSLInfo)
20 {
21 #if OPENSSL_VERSION_NUMBER >= 0x10100003L
22     if (OPENSSL_init_ssl(OPENSSL_INIT_LOAD_CONFIG, NULL) == 0) {
23         WRITE_LOG(LOG_FATAL, "OPENSSL_init_ssl");
24     }
25     ERR_clear_error();
26 #else
27     SSL_library_init();
28     OpenSSL_add_all_algorithms();
29     SSL_load_error_strings();
30 #endif
31     cipher = hSSLInfo->cipher;
32     sessionId = hSSLInfo->sessionId;
33     isDaemon = hSSLInfo->isDaemon;
34     if (memset_s(preSharedKey, sizeof(preSharedKey), 0, sizeof(preSharedKey)) != EOK) {
35         WRITE_LOG(LOG_FATAL, "memset_s preSharedKey failed");
36     }
37 }
38 
~HdcSSLBase()39 HdcSSLBase::~HdcSSLBase()
40 {
41     if (!isInited) {
42         return;
43     }
44     BIO_reset(outBIO);
45     BIO_reset(inBIO);
46     SSL_free(ssl);
47     inBIO = nullptr;
48     outBIO = nullptr;
49     ssl = nullptr;
50     SSL_CTX_free(sslCtx);
51     sslCtx = nullptr;
52     WRITE_LOG(LOG_INFO, "SSL free finished for sid:%u", sessionId);
53     isInited = false;
54 }
55 
SetSSLInfo(SSLInfoPtr hSSLInfo,HSession hSession)56 void HdcSSLBase::SetSSLInfo(SSLInfoPtr hSSLInfo, HSession hSession)
57 {
58     hSSLInfo->cipher = TLS_AES_128_GCM_SHA256;
59     hSSLInfo->isDaemon = !hSession->serverOrDaemon;
60     hSSLInfo->sessionId = hSession->sessionId;
61 }
62 
InitSSL()63 int HdcSSLBase::InitSSL()
64 {
65     const SSL_METHOD *method = SetSSLMethod();
66     sslCtx = SSL_CTX_new(method);
67     if (sslCtx == nullptr) {
68         WRITE_LOG(LOG_FATAL, "SSL_CTX_new failed");
69         return ERR_GENERIC;
70     }
71     SetPskCallback();
72     SSL_CTX_set_ciphersuites(sslCtx, cipher.c_str());
73     inBIO = BIO_new(BIO_s_mem());
74     outBIO = BIO_new(BIO_s_mem());
75     if (inBIO == nullptr || outBIO == nullptr) {
76         WRITE_LOG(LOG_FATAL, "BIO_new failed");
77         return ERR_GENERIC;
78     }
79     ssl = SSL_new(sslCtx);
80     if (ssl == nullptr) {
81         WRITE_LOG(LOG_FATAL, "SSL_new failed");
82         return ERR_GENERIC;
83     }
84     SetSSLState();
85     SSL_set_bio(ssl, inBIO, outBIO);
86     isInited = true;
87     WRITE_LOG(LOG_DEBUG, "SSL init finished for sid:%u", sessionId);
88     return RET_SUCCESS;
89 }
90 
DoSSLWrite(const int bufLen,uint8_t * bufPtr)91 int HdcSSLBase::DoSSLWrite(const int bufLen, uint8_t *bufPtr)
92 {
93     int retSSL = SSL_write(ssl, bufPtr, bufLen);
94     if (retSSL < 0) {
95         WRITE_LOG(LOG_WARN, "SSL write error, ret:%d", retSSL);
96         int err = SSL_get_error(ssl, retSSL);
97         if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
98             DEBUG_LOG("SSL write error, ret:%d, err:%d, retry");
99         } else {
100             WRITE_LOG(LOG_WARN, "SSL write error, ret:%d, err:%d", retSSL, err);
101         }
102     }
103     return retSSL;
104 }
105 
Encrypt(const int bufLen,uint8_t * bufPtr)106 int HdcSSLBase::Encrypt(const int bufLen, uint8_t *bufPtr)
107 {
108     int retSSL = DoSSLWrite(bufLen, bufPtr);
109     if (retSSL < 0) {
110         return retSSL;
111     }
112     int sslBufLen = GetOutPending();
113     int retBIO = DoBIORead(bufPtr, sslBufLen);
114     return retBIO;
115 }
116 
DoSSLRead(const int bufLen,int & index,uint8_t * bufPtr)117 int HdcSSLBase::DoSSLRead(const int bufLen, int &index, uint8_t *bufPtr)
118 {
119     int nSSLRead = SSL_read(ssl, bufPtr + index, std::min(static_cast<int>(BUF_SIZE_DEFAULT16), bufLen - index));
120     if (nSSLRead < 0) {
121         int err = SSL_get_error(ssl, nSSLRead);
122         if (err == SSL_ERROR_WANT_READ) {
123             DEBUG_LOG("SSL_ERROR_WANT_READ, availTailIndex: %d", index);
124             if (index > static_cast<int>(BUF_SIZE_DEFAULT16)) {
125                 return RET_SUCCESS;
126             }
127             return ERR_DECRYPT_WANT_READ;
128         }
129         WRITE_LOG(LOG_FATAL, "nSSLRead is failed errno: %d", err);
130         return ERR_GENERIC;
131     } else {
132         index += nSSLRead;
133         return GetInPending();
134     }
135 }
136 
DoBIOWrite(uint8_t * bufPtr,const int nread) const137 int HdcSSLBase::DoBIOWrite(uint8_t *bufPtr, const int nread) const
138 {
139     return BIO_write(inBIO, bufPtr, nread);
140 }
141 
DoBIORead(uint8_t * bufPtr,const int bufLen) const142 int HdcSSLBase::DoBIORead(uint8_t *bufPtr, const int bufLen) const
143 {
144     return BIO_read(outBIO, bufPtr, bufLen);
145 }
146 
147 
IsHandshakeFinish() const148 bool HdcSSLBase::IsHandshakeFinish() const
149 {
150     // return 0 if handshake is finished, else return 1
151     return (SSL_is_init_finished(ssl) != 0);
152 }
153 
DoHandshake()154 int HdcSSLBase::DoHandshake()
155 {
156     int ret = SSL_do_handshake(ssl);
157     if (ret < 0) {
158         int err = SSL_get_error(ssl, ret);
159         if (err != SSL_ERROR_WANT_READ) {
160             WRITE_LOG(LOG_DEBUG, "SSL_do_handshake error ret is %d ", err);
161         }
162     }
163     return ret;
164 }
165 
166 // use with BIO_read and SSL_write
GetOutPending() const167 int HdcSSLBase::GetOutPending() const
168 {
169     return BIO_pending(outBIO);
170 }
171 
172 // use with BIO_write and SSL_read
GetInPending() const173 int HdcSSLBase::GetInPending() const
174 {
175     return BIO_pending(inBIO);
176 }
177 
ShowSSLInfo()178 void HdcSSLBase::ShowSSLInfo()
179 {
180     WRITE_LOG(LOG_INFO, "SSL handshake status is %d, version is %s, cipher is %s",
181         SSL_is_init_finished(ssl), SSL_get_version(ssl), SSL_get_cipher_name(ssl));
182 }
183 
ClearPsk()184 bool HdcSSLBase::ClearPsk()
185 {
186     // NOTE: 32 is the max length of psk
187     if (memset_s(preSharedKey, sizeof(preSharedKey), 0, sizeof(preSharedKey)) != EOK) {
188         WRITE_LOG(LOG_FATAL, "ClearPsk memset_s failed");
189         return false;
190     }
191     return true;
192 }
193 
InputPsk(unsigned char * psk,int pskLen)194 bool HdcSSLBase::InputPsk(unsigned char *psk, int pskLen)
195 {
196     if (pskLen > BUF_SIZE_PSK) {
197         WRITE_LOG(LOG_FATAL, "pskLen is too long, pskLen = %d", pskLen);
198         return false;
199     }
200     if (memcpy_s(preSharedKey, sizeof(preSharedKey), psk, pskLen) != EOK) {
201         WRITE_LOG(LOG_FATAL, "memcpy_s failed");
202         return false;
203     }
204     return true;
205 }
206 
GenPsk()207 bool HdcSSLBase::GenPsk()
208 {
209     unsigned char* buf = preSharedKey;
210     if (RAND_priv_bytes(buf, BUF_SIZE_PSK) != 1) {
211         WRITE_LOG(LOG_FATAL, "RAND_priv_bytes failed");
212         return false;
213     }
214     return true;
215 }
216 
GetPskEncrypt(unsigned char * bufPtr,const int bufLen,const string & pubkey)217 int HdcSSLBase::GetPskEncrypt(unsigned char *bufPtr, const int bufLen, const string &pubkey)
218 {
219     if (bufLen < BUF_SIZE_PSK_ENCRYPTED) {
220         WRITE_LOG(LOG_FATAL, "bufLen is too short, bufLen = %d", bufLen);
221         return ERR_GENERIC;
222     }
223     unsigned char* buf = preSharedKey;
224     int payloadSize = RsaPubkeyEncrypt(buf, BUF_SIZE_PSK, bufPtr, bufLen, pubkey);
225     WRITE_LOG(LOG_INFO, "RsaPubkeyEncrypt payloadSize = %d, sid: %u", payloadSize, sessionId);
226     return payloadSize; // return the size of encrypted psk
227 }
228 
Decrypt(const int nread,const int bufLen,uint8_t * bufPtr,int & index)229 int HdcSSLBase::Decrypt(const int nread, const int bufLen, uint8_t *bufPtr, int &index)
230 {
231     // the bufPtr need at least BUF_SIZE_DEFAULT16
232     if (!SSL_is_init_finished(ssl)) {
233         WRITE_LOG(LOG_WARN, "SSL is not init finished");
234         return ERR_GENERIC;
235     }
236     int left = nread;
237     int retBio = DoBIOWrite(bufPtr + index, nread); // write to "in"
238     if (retBio < 0) {
239         WRITE_LOG(LOG_WARN, "BIO write failed, ret is %d", retBio);
240         return ERR_GENERIC;
241     }
242     while (left > 0) {
243         int result = DoSSLRead(bufLen, index, bufPtr); // read from ssl, output to bufPtr
244         if (result < 0) {
245             return result;
246         }
247         left = result;
248         DEBUG_LOG("nread=%d, retbio=%d, sslread = %d, left = %d", nread, retBio, index, left);
249     }
250     return RET_SUCCESS;
251 }
252 
PskServerCallback(SSL * ssl,const char * identity,unsigned char * psk,unsigned int maxPskLen)253 unsigned int HdcSSLBase::PskServerCallback(SSL *ssl, const char *identity, unsigned char *psk, unsigned int maxPskLen)
254 {
255     SSL_CTX *sslctx = SSL_get_SSL_CTX(ssl);
256     void *exData = SSL_CTX_get_ex_data(sslctx, 0);
257     if (exData == nullptr) {
258         WRITE_LOG(LOG_FATAL, "exData is null");
259         return 0;
260     }
261     unsigned char *pskInput = reinterpret_cast<unsigned char*>(exData);
262     if (strcmp(identity, STR_PSK_IDENTITY.c_str()) != 0) {
263         WRITE_LOG(LOG_FATAL, "identity not same");
264         return 0;
265     }
266     unsigned int keyLen = BUF_SIZE_PSK;
267     if (keyLen > maxPskLen) {
268         WRITE_LOG(LOG_FATAL, "Server PSK key length invalid, maxpsklen = %d, keyLen = %d", maxPskLen, keyLen);
269         return 0;
270     }
271     if (memcpy_s(psk, maxPskLen, pskInput, keyLen) != EOK) {
272         WRITE_LOG(LOG_FATAL, "memcpy failed, maxpsklen = %d, keyLen = %d", maxPskLen, keyLen);
273         return 0;
274     }
275     return keyLen;
276 }
277 
PskClientCallback(SSL * ssl,const char * hint,char * identity,unsigned int maxIdentityLen,unsigned char * psk,unsigned int maxPskLen)278 unsigned int HdcSSLBase::PskClientCallback(SSL *ssl, const char *hint, char *identity, unsigned int maxIdentityLen,
279     unsigned char *psk, unsigned int maxPskLen)
280 {
281     SSL_CTX *sslctx = SSL_get_SSL_CTX(ssl);
282     void *exData = SSL_CTX_get_ex_data(sslctx, 0);
283     if (exData == nullptr) {
284         WRITE_LOG(LOG_FATAL, "exData is null");
285         return 0;
286     }
287     unsigned char *pskInput = reinterpret_cast<unsigned char*>(exData);
288     if (STR_PSK_IDENTITY.size() + 1 > maxIdentityLen) {
289         WRITE_LOG(LOG_FATAL, "Client identity buffer too small, maxIdentityLen = %d", maxIdentityLen);
290         return 0;
291     }
292     if (strcpy_s(identity, maxIdentityLen, "Client_identity") != EOK) {
293         WRITE_LOG(LOG_FATAL, "Client PSK key strcpy_s identity failed, maxIdentityLen is %u", maxIdentityLen);
294         return 0;
295     }
296     unsigned int keyLen = BUF_SIZE_PSK;
297     if (keyLen > maxPskLen) {
298         WRITE_LOG(LOG_FATAL, "Client PSK key length invalid, maxpsklen = %d, keyLen = %d", maxPskLen, keyLen);
299         return 0;
300     }
301     if (memcpy_s(psk, maxPskLen, pskInput, keyLen) != EOK) {
302         WRITE_LOG(LOG_INFO, "memcpy failed, maxpsklen = %d, keyLen = %d", maxPskLen, keyLen);
303         return 0;
304     }
305 
306     return keyLen;
307 }
308 
RsaPrikeyDecrypt(const unsigned char * inBuf,int inLen,unsigned char * outBuf,int outBufLen)309 int HdcSSLBase::RsaPrikeyDecrypt(const unsigned char *inBuf, int inLen, unsigned char *outBuf, int outBufLen)
310 {
311     int outLen = -1;
312 #ifdef HDC_HOST
313     outLen = HdcAuth::RsaPrikeyDecryptPsk(inBuf, inLen, outBuf, outBufLen);
314 #endif
315     return outLen;
316 }
317 
RsaPubkeyEncrypt(const unsigned char * inBuf,int inLen,unsigned char * outBuf,int outBufSize,const string & pubkey)318 int HdcSSLBase::RsaPubkeyEncrypt(const unsigned char *inBuf, int inLen,
319     unsigned char *outBuf, int outBufSize, const string &pubkey)
320 {
321     int outLen = -1;
322 #ifndef HDC_HOST
323     outLen = HdcAuth::RsaPubkeyEncryptPsk(inBuf, inLen, outBuf, outBufSize, pubkey);
324 #endif
325     return outLen;
326 }
327 
PerformHandshake(vector<uint8_t> & outBuf)328 int HdcSSLBase::PerformHandshake(vector<uint8_t> &outBuf)
329 {
330     if (IsHandshakeFinish()) {
331         return RET_SSL_HANDSHAKE_FINISHED;
332     }
333     DoHandshake();
334     int nread = GetOutPending();
335     if (nread <= 0) {
336         WRITE_LOG(LOG_WARN, "SSL PerformHandshake failed, nread = %d", nread);
337         return ERR_GENERIC;
338     }
339     outBuf.resize(nread);
340     int outLen = DoBIORead(outBuf.data(), nread);
341     if (outLen < 0) {
342         WRITE_LOG(LOG_WARN, "BIO_read failed");
343         return ERR_GENERIC;
344     }
345     while (outLen < nread) {
346         int tempLen = DoBIORead(outBuf.data() + outLen, nread - outLen);
347         if (tempLen > 0) {
348             outLen += tempLen;
349             WRITE_LOG(LOG_WARN, "PerformHandshake BIO_read left data size %d", tempLen);
350         } else if (tempLen == 0) {
351             break;
352         } else {
353             WRITE_LOG(LOG_FATAL, "DoBIORead failed");
354             return ERR_GENERIC;
355         }
356     }
357     return RET_SUCCESS;
358 }
359 
SetHandshakeLabel(HSession hSession)360 bool HdcSSLBase::SetHandshakeLabel(HSession hSession)
361 {
362     if (!IsHandshakeFinish()) {
363         return false;
364     }
365     ShowSSLInfo();
366     hSession->sslHandshake = true;
367     return true;
368 }
369 } // namespace Hdc
370 #endif // HDC_SUPPORT_ENCRYPT_TCP