• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2007-2008 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /**
18  * Native glue for Java class org.apache.harmony.xnet.provider.jsse.NativeCrypto
19  */
20 
21 #define LOG_TAG "NativeCrypto"
22 
23 #include <fcntl.h>
24 #include <sys/socket.h>
25 #include <unistd.h>
26 
27 #include <jni.h>
28 
29 #include <openssl/dsa.h>
30 #include <openssl/err.h>
31 #include <openssl/evp.h>
32 #include <openssl/rand.h>
33 #include <openssl/rsa.h>
34 #include <openssl/ssl.h>
35 
36 #include "AsynchronousSocketCloseMonitor.h"
37 #include "JNIHelp.h"
38 #include "JniConstants.h"
39 #include "JniException.h"
40 #include "LocalArray.h"
41 #include "NetFd.h"
42 #include "NetworkUtilities.h"
43 #include "ScopedLocalRef.h"
44 #include "ScopedPrimitiveArray.h"
45 #include "ScopedUtfChars.h"
46 #include "UniquePtr.h"
47 
// Verbose JNI tracing. Change the #undef below to #define to enable it.
#undef WITH_JNI_TRACE
#ifdef WITH_JNI_TRACE
// The expansion deliberately has no trailing ';' so that `JNI_TRACE(...);`
// is exactly one statement and stays safe inside an unbraced if/else.
#define JNI_TRACE(...) \
        ((void)LOG(LOG_INFO, LOG_TAG "-jni", __VA_ARGS__))
#else
#define JNI_TRACE(...) ((void)0)
#endif
60 
61 struct BIO_Delete {
operator ()BIO_Delete62     void operator()(BIO* p) const {
63         BIO_free(p);
64     }
65 };
66 typedef UniquePtr<BIO, BIO_Delete> Unique_BIO;
67 
68 struct BIGNUM_Delete {
operator ()BIGNUM_Delete69     void operator()(BIGNUM* p) const {
70         BN_free(p);
71     }
72 };
73 typedef UniquePtr<BIGNUM, BIGNUM_Delete> Unique_BIGNUM;
74 
75 struct DH_Delete {
operator ()DH_Delete76     void operator()(DH* p) const {
77         DH_free(p);
78     }
79 };
80 typedef UniquePtr<DH, DH_Delete> Unique_DH;
81 
82 struct DSA_Delete {
operator ()DSA_Delete83     void operator()(DSA* p) const {
84         DSA_free(p);
85     }
86 };
87 typedef UniquePtr<DSA, DSA_Delete> Unique_DSA;
88 
89 struct EVP_PKEY_Delete {
operator ()EVP_PKEY_Delete90     void operator()(EVP_PKEY* p) const {
91         EVP_PKEY_free(p);
92     }
93 };
94 typedef UniquePtr<EVP_PKEY, EVP_PKEY_Delete> Unique_EVP_PKEY;
95 
96 struct PKCS8_PRIV_KEY_INFO_Delete {
operator ()PKCS8_PRIV_KEY_INFO_Delete97     void operator()(PKCS8_PRIV_KEY_INFO* p) const {
98         PKCS8_PRIV_KEY_INFO_free(p);
99     }
100 };
101 typedef UniquePtr<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_Delete> Unique_PKCS8_PRIV_KEY_INFO;
102 
103 struct RSA_Delete {
operator ()RSA_Delete104     void operator()(RSA* p) const {
105         RSA_free(p);
106     }
107 };
108 typedef UniquePtr<RSA, RSA_Delete> Unique_RSA;
109 
110 struct SSL_Delete {
operator ()SSL_Delete111     void operator()(SSL* p) const {
112         SSL_free(p);
113     }
114 };
115 typedef UniquePtr<SSL, SSL_Delete> Unique_SSL;
116 
117 struct SSL_CTX_Delete {
operator ()SSL_CTX_Delete118     void operator()(SSL_CTX* p) const {
119         SSL_CTX_free(p);
120     }
121 };
122 typedef UniquePtr<SSL_CTX, SSL_CTX_Delete> Unique_SSL_CTX;
123 
124 struct X509_Delete {
operator ()X509_Delete125     void operator()(X509* p) const {
126         X509_free(p);
127     }
128 };
129 typedef UniquePtr<X509, X509_Delete> Unique_X509;
130 
131 struct X509_NAME_Delete {
operator ()X509_NAME_Delete132     void operator()(X509_NAME* p) const {
133         X509_NAME_free(p);
134     }
135 };
136 typedef UniquePtr<X509_NAME, X509_NAME_Delete> Unique_X509_NAME;
137 
138 struct sk_SSL_CIPHER_Delete {
operator ()sk_SSL_CIPHER_Delete139     void operator()(STACK_OF(SSL_CIPHER)* p) const {
140         sk_SSL_CIPHER_free(p);
141     }
142 };
143 typedef UniquePtr<STACK_OF(SSL_CIPHER), sk_SSL_CIPHER_Delete> Unique_sk_SSL_CIPHER;
144 
145 struct sk_X509_Delete {
operator ()sk_X509_Delete146     void operator()(STACK_OF(X509)* p) const {
147         sk_X509_free(p);
148     }
149 };
150 typedef UniquePtr<STACK_OF(X509), sk_X509_Delete> Unique_sk_X509;
151 
152 struct sk_X509_NAME_Delete {
operator ()sk_X509_NAME_Delete153     void operator()(STACK_OF(X509_NAME)* p) const {
154         sk_X509_NAME_free(p);
155     }
156 };
157 typedef UniquePtr<STACK_OF(X509_NAME), sk_X509_NAME_Delete> Unique_sk_X509_NAME;
158 
159 /**
160  * Frees the SSL error state.
161  *
162  * OpenSSL keeps an "error stack" per thread, and given that this code
163  * can be called from arbitrary threads that we don't keep track of,
164  * we err on the side of freeing the error state promptly (instead of,
165  * say, at thread death).
166  */
freeSslErrorState(void)167 static void freeSslErrorState(void) {
168     ERR_clear_error();
169     ERR_remove_state(0);
170 }
171 
172 /*
173  * Checks this thread's OpenSSL error queue and throws a RuntimeException if
174  * necessary.
175  *
176  * @return 1 if an exception was thrown, 0 if not.
177  */
throwExceptionIfNecessary(JNIEnv * env,const char * location)178 static int throwExceptionIfNecessary(JNIEnv* env, const char* location  __attribute__ ((unused))) {
179     int error = ERR_get_error();
180     int result = 0;
181 
182     if (error != 0) {
183         char message[256];
184         ERR_error_string_n(error, message, sizeof(message));
185         JNI_TRACE("OpenSSL error in %s %d: %s", location, error, message);
186         jniThrowRuntimeException(env, message);
187         result = 1;
188     }
189 
190     freeSslErrorState();
191     return result;
192 }
193 
194 /**
195  * Throws an SocketTimeoutException with the given string as a message.
196  */
throwSocketTimeoutException(JNIEnv * env,const char * message)197 static void throwSocketTimeoutException(JNIEnv* env, const char* message) {
198     JNI_TRACE("throwSocketTimeoutException %s", message);
199     jniThrowException(env, "java/net/SocketTimeoutException", message);
200 }
201 
202 /**
203  * Throws a javax.net.ssl.SSLException with the given string as a message.
204  */
throwSSLExceptionStr(JNIEnv * env,const char * message)205 static void throwSSLExceptionStr(JNIEnv* env, const char* message) {
206     JNI_TRACE("throwSSLExceptionStr %s", message);
207     jniThrowException(env, "javax/net/ssl/SSLException", message);
208 }
209 
210 /**
211  * Throws a javax.net.ssl.SSLProcotolException with the given string as a message.
212  */
throwSSLProtocolExceptionStr(JNIEnv * env,const char * message)213 static void throwSSLProtocolExceptionStr(JNIEnv* env, const char* message) {
214     JNI_TRACE("throwSSLProtocolExceptionStr %s", message);
215     jniThrowException(env, "javax/net/ssl/SSLProtocolException", message);
216 }
217 
218 /**
219  * Throws an SSLException with a message constructed from the current
220  * SSL errors. This will also log the errors.
221  *
222  * @param env the JNI environment
223  * @param ssl the possibly NULL SSL
224  * @param sslErrorCode error code returned from SSL_get_error() or
225  * SSL_ERROR_NONE to probe with ERR_get_error
226  * @param message null-ok; general error message
227  */
throwSSLExceptionWithSslErrors(JNIEnv * env,SSL * ssl,int sslErrorCode,const char * message)228 static void throwSSLExceptionWithSslErrors(
229         JNIEnv* env, SSL* ssl, int sslErrorCode, const char* message) {
230 
231     if (message == NULL) {
232         message = "SSL error";
233     }
234 
235     // First consult the SSL error code for the general message.
236     const char* sslErrorStr = NULL;
237     switch (sslErrorCode) {
238         case SSL_ERROR_NONE:
239             if (ERR_peek_error() == 0) {
240                 sslErrorStr = "OK";
241             } else {
242                 sslErrorStr = "";
243             }
244             break;
245         case SSL_ERROR_SSL:
246             sslErrorStr = "Failure in SSL library, usually a protocol error";
247             break;
248         case SSL_ERROR_WANT_READ:
249             sslErrorStr = "SSL_ERROR_WANT_READ occurred. You should never see this.";
250             break;
251         case SSL_ERROR_WANT_WRITE:
252             sslErrorStr = "SSL_ERROR_WANT_WRITE occurred. You should never see this.";
253             break;
254         case SSL_ERROR_WANT_X509_LOOKUP:
255             sslErrorStr = "SSL_ERROR_WANT_X509_LOOKUP occurred. You should never see this.";
256             break;
257         case SSL_ERROR_SYSCALL:
258             sslErrorStr = "I/O error during system call";
259             break;
260         case SSL_ERROR_ZERO_RETURN:
261             sslErrorStr = "SSL_ERROR_ZERO_RETURN occurred. You should never see this.";
262             break;
263         case SSL_ERROR_WANT_CONNECT:
264             sslErrorStr = "SSL_ERROR_WANT_CONNECT occurred. You should never see this.";
265             break;
266         case SSL_ERROR_WANT_ACCEPT:
267             sslErrorStr = "SSL_ERROR_WANT_ACCEPT occurred. You should never see this.";
268             break;
269         default:
270             sslErrorStr = "Unknown SSL error";
271     }
272 
273     // Prepend either our explicit message or a default one.
274     char* str;
275     if (asprintf(&str, "%s: ssl=%p: %s", message, ssl, sslErrorStr) <= 0) {
276         // problem with asprintf, just throw argument message, log everything
277         throwSSLExceptionStr(env, message);
278         LOGV("%s: ssl=%p: %s", message, ssl, sslErrorStr);
279         freeSslErrorState();
280         return;
281     }
282 
283     char* allocStr = str;
284 
285     // For protocol errors, SSL might have more information.
286     if (sslErrorCode == SSL_ERROR_NONE || sslErrorCode == SSL_ERROR_SSL) {
287         // Append each error as an additional line to the message.
288         for (;;) {
289             char errStr[256];
290             const char* file;
291             int line;
292             const char* data;
293             int flags;
294             unsigned long err = ERR_get_error_line_data(&file, &line, &data, &flags);
295             if (err == 0) {
296                 break;
297             }
298 
299             ERR_error_string_n(err, errStr, sizeof(errStr));
300 
301             int ret = asprintf(&str, "%s\n%s (%s:%d %p:0x%08x)",
302                                (allocStr == NULL) ? "" : allocStr,
303                                errStr,
304                                file,
305                                line,
306                                (flags & ERR_TXT_STRING) ? data : "(no data)",
307                                flags);
308 
309             if (ret < 0) {
310                 break;
311             }
312 
313             free(allocStr);
314             allocStr = str;
315         }
316     // For errors during system calls, errno might be our friend.
317     } else if (sslErrorCode == SSL_ERROR_SYSCALL) {
318         if (asprintf(&str, "%s, %s", allocStr, strerror(errno)) >= 0) {
319             free(allocStr);
320             allocStr = str;
321         }
322     // If the error code is invalid, print it.
323     } else if (sslErrorCode > SSL_ERROR_WANT_ACCEPT) {
324         if (asprintf(&str, ", error code is %d", sslErrorCode) >= 0) {
325             free(allocStr);
326             allocStr = str;
327         }
328     }
329 
330     if (sslErrorCode == SSL_ERROR_SSL) {
331         throwSSLProtocolExceptionStr(env, allocStr);
332     } else {
333         throwSSLExceptionStr(env, allocStr);
334     }
335 
336     LOGV("%s", allocStr);
337     free(allocStr);
338     freeSslErrorState();
339 }
340 
341 /**
342  * Helper function that grabs the casts an ssl pointer and then checks for nullness.
343  * If this function returns NULL and <code>throwIfNull</code> is
344  * passed as <code>true</code>, then this function will call
345  * <code>throwSSLExceptionStr</code> before returning, so in this case of
346  * NULL, a caller of this function should simply return and allow JNI
347  * to do its thing.
348  *
349  * @param env the JNI environment
350  * @param ssl_address; the ssl_address pointer as an integer
351  * @param throwIfNull whether to throw if the SSL pointer is NULL
352  * @returns the pointer, which may be NULL
353  */
to_SSL_CTX(JNIEnv * env,int ssl_ctx_address,bool throwIfNull)354 static SSL_CTX* to_SSL_CTX(JNIEnv* env, int ssl_ctx_address, bool throwIfNull) {
355     SSL_CTX* ssl_ctx = reinterpret_cast<SSL_CTX*>(static_cast<uintptr_t>(ssl_ctx_address));
356     if ((ssl_ctx == NULL) && throwIfNull) {
357         JNI_TRACE("ssl_ctx == null");
358         jniThrowNullPointerException(env, "ssl_ctx == null");
359     }
360     return ssl_ctx;
361 }
362 
to_SSL(JNIEnv * env,int ssl_address,bool throwIfNull)363 static SSL* to_SSL(JNIEnv* env, int ssl_address, bool throwIfNull) {
364     SSL* ssl = reinterpret_cast<SSL*>(static_cast<uintptr_t>(ssl_address));
365     if ((ssl == NULL) && throwIfNull) {
366         JNI_TRACE("ssl == null");
367         jniThrowNullPointerException(env, "ssl == null");
368     }
369     return ssl;
370 }
371 
to_SSL_SESSION(JNIEnv * env,int ssl_session_address,bool throwIfNull)372 static SSL_SESSION* to_SSL_SESSION(JNIEnv* env, int ssl_session_address, bool throwIfNull) {
373     SSL_SESSION* ssl_session
374         = reinterpret_cast<SSL_SESSION*>(static_cast<uintptr_t>(ssl_session_address));
375     if ((ssl_session == NULL) && throwIfNull) {
376         JNI_TRACE("ssl_session == null");
377         jniThrowNullPointerException(env, "ssl_session == null");
378     }
379     return ssl_session;
380 }
381 
382 /**
383  * Converts a Java byte[] to an OpenSSL BIGNUM, allocating the BIGNUM on the
384  * fly.
385  */
arrayToBignum(JNIEnv * env,jbyteArray source)386 static BIGNUM* arrayToBignum(JNIEnv* env, jbyteArray source) {
387     JNI_TRACE("arrayToBignum(%p)", source);
388 
389     ScopedByteArrayRO sourceBytes(env, source);
390     if (sourceBytes.get() == NULL) {
391         JNI_TRACE("arrayToBignum(%p) => NULL", source);
392         return NULL;
393     }
394     BIGNUM* bn = BN_bin2bn(reinterpret_cast<const unsigned char*>(sourceBytes.get()),
395                            sourceBytes.size(),
396                            NULL);
397     JNI_TRACE("arrayToBignum(%p) => %p", source, bn);
398     return bn;
399 }
400 
/**
 * OpenSSL locking support, following the pthread approach from the O'Reilly
 * book by Viega et al. OpenSSL identifies each lock by an index n, and
 * mutex_buf[n] is the pthread mutex backing it.
 */
#define MUTEX_TYPE       pthread_mutex_t
#define MUTEX_SETUP(x)   pthread_mutex_init(&(x), NULL)
#define MUTEX_CLEANUP(x) pthread_mutex_destroy(&(x))
#define MUTEX_LOCK(x)    pthread_mutex_lock(&(x))
#define MUTEX_UNLOCK(x)  pthread_mutex_unlock(&(x))
#define THREAD_ID        pthread_self()

// Negative status codes. They are unrelated to locking, and their users are
// outside this excerpt. The names suggest which Java exception each one
// signals (confirm against callers).
#define THROW_EXCEPTION              (-2)
#define THROW_SOCKETTIMEOUTEXCEPTION (-3)
#define THROWN_SOCKETEXCEPTION       (-4)

// One mutex per OpenSSL lock: allocated by THREAD_setup, released by
// THREAD_cleanup. NULL while no locking support is installed.
static MUTEX_TYPE* mutex_buf = NULL;
417 
locking_function(int mode,int n,const char *,int)418 static void locking_function(int mode, int n, const char*, int) {
419     if (mode & CRYPTO_LOCK) {
420         MUTEX_LOCK(mutex_buf[n]);
421     } else {
422         MUTEX_UNLOCK(mutex_buf[n]);
423     }
424 }
425 
id_function(void)426 static unsigned long id_function(void) {
427     return ((unsigned long)THREAD_ID);
428 }
429 
THREAD_setup(void)430 int THREAD_setup(void) {
431     mutex_buf = new MUTEX_TYPE[CRYPTO_num_locks()];
432     if (!mutex_buf) {
433         return 0;
434     }
435 
436     for (int i = 0; i < CRYPTO_num_locks(); ++i) {
437         MUTEX_SETUP(mutex_buf[i]);
438     }
439 
440     CRYPTO_set_id_callback(id_function);
441     CRYPTO_set_locking_callback(locking_function);
442 
443     return 1;
444 }
445 
THREAD_cleanup(void)446 int THREAD_cleanup(void) {
447     if (!mutex_buf) {
448         return 0;
449     }
450 
451     CRYPTO_set_id_callback(NULL);
452     CRYPTO_set_locking_callback(NULL);
453 
454     for (int i = 0; i < CRYPTO_num_locks( ); i++) {
455         MUTEX_CLEANUP(mutex_buf[i]);
456     }
457 
458     free(mutex_buf);
459     mutex_buf = NULL;
460 
461     return 1;
462 }
463 
464 /**
465  * Initialization phase for every OpenSSL job: Loads the Error strings, the
466  * crypto algorithms and reset the OpenSSL library
467  */
NativeCrypto_clinit(JNIEnv *,jclass)468 static void NativeCrypto_clinit(JNIEnv*, jclass)
469 {
470     SSL_load_error_strings();
471     ERR_load_crypto_strings();
472     SSL_library_init();
473     OpenSSL_add_all_algorithms();
474     THREAD_setup();
475 }
476 
477 /**
478  * public static native int EVP_PKEY_new_DSA(byte[] p, byte[] q, byte[] g,
479  *                                           byte[] pub_key, byte[] priv_key);
480  */
NativeCrypto_EVP_PKEY_new_DSA(JNIEnv * env,jclass,jbyteArray p,jbyteArray q,jbyteArray g,jbyteArray pub_key,jbyteArray priv_key)481 static EVP_PKEY* NativeCrypto_EVP_PKEY_new_DSA(JNIEnv* env, jclass,
482                                                jbyteArray p, jbyteArray q, jbyteArray g,
483                                                jbyteArray pub_key, jbyteArray priv_key) {
484     JNI_TRACE("EVP_PKEY_new_DSA(p=%p, q=%p, g=%p, pub_key=%p, priv_key=%p)",
485               p, q, g, pub_key, priv_key);
486 
487     Unique_DSA dsa(DSA_new());
488     if (dsa.get() == NULL) {
489         jniThrowRuntimeException(env, "DSA_new failed");
490         return NULL;
491     }
492 
493     dsa->p = arrayToBignum(env, p);
494     dsa->q = arrayToBignum(env, q);
495     dsa->g = arrayToBignum(env, g);
496     dsa->pub_key = arrayToBignum(env, pub_key);
497 
498     if (priv_key != NULL) {
499         dsa->priv_key = arrayToBignum(env, priv_key);
500     }
501 
502     if (dsa->p == NULL || dsa->q == NULL || dsa->g == NULL || dsa->pub_key == NULL) {
503         jniThrowRuntimeException(env, "Unable to convert BigInteger to BIGNUM");
504         return NULL;
505     }
506 
507     Unique_EVP_PKEY pkey(EVP_PKEY_new());
508     if (pkey.get() == NULL) {
509         jniThrowRuntimeException(env, "EVP_PKEY_new failed");
510         return NULL;
511     }
512     if (EVP_PKEY_assign_DSA(pkey.get(), dsa.get()) != 1) {
513         jniThrowRuntimeException(env, "EVP_PKEY_assign_DSA failed");
514         return NULL;
515     }
516     dsa.release();
517     JNI_TRACE("EVP_PKEY_new_DSA(p=%p, q=%p, g=%p, pub_key=%p, priv_key=%p) => %p",
518               p, q, g, pub_key, priv_key, pkey.get());
519     return pkey.release();
520 }
521 
522 /**
523  * private static native int EVP_PKEY_new_RSA(byte[] n, byte[] e, byte[] d, byte[] p, byte[] q);
524  */
NativeCrypto_EVP_PKEY_new_RSA(JNIEnv * env,jclass,jbyteArray n,jbyteArray e,jbyteArray d,jbyteArray p,jbyteArray q)525 static EVP_PKEY* NativeCrypto_EVP_PKEY_new_RSA(JNIEnv* env, jclass,
526                                                jbyteArray n, jbyteArray e, jbyteArray d,
527                                                jbyteArray p, jbyteArray q) {
528     JNI_TRACE("EVP_PKEY_new_RSA(n=%p, e=%p, d=%p, p=%p, q=%p)", n, e, d, p, q);
529 
530     Unique_RSA rsa(RSA_new());
531     if (rsa.get() == NULL) {
532         jniThrowRuntimeException(env, "RSA_new failed");
533         return NULL;
534     }
535 
536     rsa->n = arrayToBignum(env, n);
537     rsa->e = arrayToBignum(env, e);
538 
539     if (d != NULL) {
540         rsa->d = arrayToBignum(env, d);
541     }
542 
543     if (p != NULL) {
544         rsa->p = arrayToBignum(env, p);
545     }
546 
547     if (q != NULL) {
548         rsa->q = arrayToBignum(env, q);
549     }
550 
551 #ifdef WITH_JNI_TRACE
552     if (p != NULL && q != NULL) {
553         int check = RSA_check_key(rsa.get());
554         JNI_TRACE("EVP_PKEY_new_RSA(...) RSA_check_key returns %d", check);
555     }
556 #endif
557 
558     if (rsa->n == NULL || rsa->e == NULL) {
559         jniThrowRuntimeException(env, "Unable to convert BigInteger to BIGNUM");
560         return NULL;
561     }
562 
563     Unique_EVP_PKEY pkey(EVP_PKEY_new());
564     if (pkey.get() == NULL) {
565         jniThrowRuntimeException(env, "EVP_PKEY_new failed");
566         return NULL;
567     }
568     if (EVP_PKEY_assign_RSA(pkey.get(), rsa.get()) != 1) {
569         jniThrowRuntimeException(env, "EVP_PKEY_new failed");
570         return NULL;
571     }
572     rsa.release();
573     JNI_TRACE("EVP_PKEY_new_RSA(n=%p, e=%p, d=%p, p=%p, q=%p) => %p", n, e, d, p, q, pkey.get());
574     return pkey.release();
575 }
576 
577 /**
578  * private static native void EVP_PKEY_free(int pkey);
579  */
NativeCrypto_EVP_PKEY_free(JNIEnv *,jclass,EVP_PKEY * pkey)580 static void NativeCrypto_EVP_PKEY_free(JNIEnv*, jclass, EVP_PKEY* pkey) {
581     JNI_TRACE("EVP_PKEY_free(%p)", pkey);
582 
583     if (pkey != NULL) {
584         EVP_PKEY_free(pkey);
585     }
586 }
587 
588 /*
589  * public static native int EVP_MD_CTX_create()
590  */
NativeCrypto_EVP_MD_CTX_create(JNIEnv * env,jclass)591 static jint NativeCrypto_EVP_MD_CTX_create(JNIEnv* env, jclass) {
592     JNI_TRACE("NativeCrypto_EVP_MD_CTX_create");
593 
594     EVP_MD_CTX* ctx = EVP_MD_CTX_create();
595     if (ctx == NULL) {
596         jniThrowOutOfMemoryError(env, "Unable to allocate EVP_MD_CTX");
597     }
598     JNI_TRACE("NativeCrypto_EVP_MD_CTX_create => %p", ctx);
599     return (jint) ctx;
600 
601 }
602 
603 /*
604  * public static native void EVP_MD_CTX_destroy(int)
605  */
NativeCrypto_EVP_MD_CTX_destroy(JNIEnv *,jclass,EVP_MD_CTX * ctx)606 static void NativeCrypto_EVP_MD_CTX_destroy(JNIEnv*, jclass, EVP_MD_CTX* ctx) {
607     JNI_TRACE("NativeCrypto_EVP_MD_CTX_destroy(%p)", ctx);
608 
609     if (ctx != NULL) {
610         EVP_MD_CTX_destroy(ctx);
611     }
612 }
613 
614 /*
615  * public static native int EVP_MD_CTX_copy(int)
616  */
NativeCrypto_EVP_MD_CTX_copy(JNIEnv * env,jclass,EVP_MD_CTX * ctx)617 static jint NativeCrypto_EVP_MD_CTX_copy(JNIEnv* env, jclass, EVP_MD_CTX* ctx) {
618     JNI_TRACE("NativeCrypto_EVP_MD_CTX_copy(%p)", ctx);
619 
620     if (ctx == NULL) {
621         jniThrowNullPointerException(env, NULL);
622         return NULL;
623     }
624     EVP_MD_CTX* copy = EVP_MD_CTX_create();
625     if (copy == NULL) {
626         jniThrowOutOfMemoryError(env, "Unable to allocate copy of EVP_MD_CTX");
627         return NULL;
628     }
629     EVP_MD_CTX_init(copy);
630     int result = EVP_MD_CTX_copy_ex(copy, ctx);
631     if (result == 0) {
632         EVP_MD_CTX_destroy(copy);
633         jniThrowRuntimeException(env, "Unable to copy EVP_MD_CTX");
634         return NULL;
635     }
636     JNI_TRACE("NativeCrypto_EVP_MD_CTX_copy(%p) => %p", ctx, copy);
637     return (jint) copy;
638 }
639 
640 /*
641  * public static native int EVP_DigestFinal(int, byte[], int)
642  */
NativeCrypto_EVP_DigestFinal(JNIEnv * env,jclass,EVP_MD_CTX * ctx,jbyteArray hash,jint offset)643 static jint NativeCrypto_EVP_DigestFinal(JNIEnv* env, jclass, EVP_MD_CTX* ctx,
644                                          jbyteArray hash, jint offset) {
645     JNI_TRACE("NativeCrypto_EVP_DigestFinal(%p, %p, %d)", ctx, hash, offset);
646 
647     if (ctx == NULL || hash == NULL) {
648         jniThrowNullPointerException(env, NULL);
649         return -1;
650     }
651 
652     int result = -1;
653 
654     ScopedByteArrayRW hashBytes(env, hash);
655     if (hashBytes.get() == NULL) {
656         return -1;
657     }
658     EVP_DigestFinal(ctx,
659                     reinterpret_cast<unsigned char*>(hashBytes.get() + offset),
660                     reinterpret_cast<unsigned int*>(&result));
661 
662     throwExceptionIfNecessary(env, "NativeCrypto_EVP_DigestFinal");
663 
664     JNI_TRACE("NativeCrypto_EVP_DigestFinal(%p, %p, %d) => %d", ctx, hash, offset, result);
665     return result;
666 }
667 
668 /*
669  * public static native void EVP_DigestInit(int, java.lang.String)
670  */
NativeCrypto_EVP_DigestInit(JNIEnv * env,jclass,EVP_MD_CTX * ctx,jstring algorithm)671 static void NativeCrypto_EVP_DigestInit(JNIEnv* env, jclass, EVP_MD_CTX* ctx, jstring algorithm) {
672     JNI_TRACE("NativeCrypto_EVP_DigestInit(%p, %p)", ctx, algorithm);
673 
674     if (ctx == NULL || algorithm == NULL) {
675         jniThrowNullPointerException(env, NULL);
676         return;
677     }
678 
679     ScopedUtfChars algorithmChars(env, algorithm);
680     if (algorithmChars.c_str() == NULL) {
681         return;
682     }
683 
684     JNI_TRACE("NativeCrypto_EVP_DigestInit(%p, %s)", ctx, algorithmChars.c_str());
685     const EVP_MD* digest = EVP_get_digestbynid(OBJ_txt2nid(algorithmChars.c_str()));
686 
687     if (digest == NULL) {
688         jniThrowRuntimeException(env, "Hash algorithm not found");
689         return;
690     }
691 
692     EVP_DigestInit(ctx, digest);
693 
694     throwExceptionIfNecessary(env, "NativeCrypto_EVP_DigestInit");
695 }
696 
697 /*
698  * public static native int EVP_MD_CTX_size(int)
699  */
NativeCrypto_EVP_MD_CTX_size(JNIEnv * env,jclass,EVP_MD_CTX * ctx)700 static jint NativeCrypto_EVP_MD_CTX_size(JNIEnv* env, jclass, EVP_MD_CTX* ctx) {
701     JNI_TRACE("NativeCrypto_EVP_MD_CTX_size(%p)", ctx);
702 
703     if (ctx == NULL) {
704         jniThrowNullPointerException(env, NULL);
705         return -1;
706     }
707 
708     int result = EVP_MD_CTX_size(ctx);
709 
710     throwExceptionIfNecessary(env, "NativeCrypto_EVP_MD_CTX_size");
711 
712     JNI_TRACE("NativeCrypto_EVP_MD_CTX_size(%p) => %d", ctx, result);
713     return result;
714 }
715 
716 /*
717  * public static int void EVP_MD_CTX_block_size(int)
718  */
NativeCrypto_EVP_MD_CTX_block_size(JNIEnv * env,jclass,EVP_MD_CTX * ctx)719 static jint NativeCrypto_EVP_MD_CTX_block_size(JNIEnv* env, jclass, EVP_MD_CTX* ctx) {
720     JNI_TRACE("NativeCrypto_EVP_MD_CTX_block_size(%p)", ctx);
721 
722     if (ctx == NULL) {
723         jniThrowNullPointerException(env, NULL);
724         return -1;
725     }
726 
727     int result = EVP_MD_CTX_block_size(ctx);
728 
729     throwExceptionIfNecessary(env, "NativeCrypto_EVP_MD_CTX_block_size");
730 
731     JNI_TRACE("NativeCrypto_EVP_MD_CTX_block_size(%p) => %d", ctx, result);
732     return result;
733 }
734 
735 /*
736  * public static native void EVP_DigestUpdate(int, byte[], int, int)
737  */
NativeCrypto_EVP_DigestUpdate(JNIEnv * env,jclass,EVP_MD_CTX * ctx,jbyteArray buffer,jint offset,jint length)738 static void NativeCrypto_EVP_DigestUpdate(JNIEnv* env, jclass, EVP_MD_CTX* ctx,
739                                           jbyteArray buffer, jint offset, jint length) {
740     JNI_TRACE("NativeCrypto_EVP_DigestUpdate(%p, %p, %d, %d)", ctx, buffer, offset, length);
741 
742     if (offset < 0 || length < 0) {
743         jniThrowException(env, "java/lang/IndexOutOfBoundsException", NULL);
744         return;
745     }
746 
747     if (ctx == NULL || buffer == NULL) {
748         jniThrowNullPointerException(env, NULL);
749         return;
750     }
751 
752     ScopedByteArrayRO bufferBytes(env, buffer);
753     if (bufferBytes.get() == NULL) {
754         return;
755     }
756     EVP_DigestUpdate(ctx,
757                      reinterpret_cast<const unsigned char*>(bufferBytes.get() + offset),
758                      length);
759 
760     throwExceptionIfNecessary(env, "NativeCrypto_EVP_DigestUpdate");
761 }
762 
763 /*
764  * public static native void EVP_VerifyInit(int, java.lang.String)
765  */
NativeCrypto_EVP_VerifyInit(JNIEnv * env,jclass,EVP_MD_CTX * ctx,jstring algorithm)766 static void NativeCrypto_EVP_VerifyInit(JNIEnv* env, jclass, EVP_MD_CTX* ctx, jstring algorithm) {
767     JNI_TRACE("NativeCrypto_EVP_VerifyInit(%p, %p)", ctx, algorithm);
768 
769     if (ctx == NULL || algorithm == NULL) {
770         jniThrowNullPointerException(env, NULL);
771         return;
772     }
773 
774     ScopedUtfChars algorithmChars(env, algorithm);
775     if (algorithmChars.c_str() == NULL) {
776         return;
777     }
778 
779     JNI_TRACE("NativeCrypto_EVP_VerifyInit(%p, %s)", ctx, algorithmChars.c_str());
780     const EVP_MD* digest = EVP_get_digestbynid(OBJ_txt2nid(algorithmChars.c_str()));
781 
782     if (digest == NULL) {
783         jniThrowRuntimeException(env, "Hash algorithm not found");
784         return;
785     }
786 
787     EVP_VerifyInit(ctx, digest);
788 
789     throwExceptionIfNecessary(env, "NativeCrypto_EVP_VerifyInit");
790 }
791 
792 /*
793  * public static native void EVP_VerifyUpdate(int, byte[], int, int)
794  */
NativeCrypto_EVP_VerifyUpdate(JNIEnv * env,jclass,EVP_MD_CTX * ctx,jbyteArray buffer,jint offset,jint length)795 static void NativeCrypto_EVP_VerifyUpdate(JNIEnv* env, jclass, EVP_MD_CTX* ctx,
796                                           jbyteArray buffer, jint offset, jint length) {
797     JNI_TRACE("NativeCrypto_EVP_VerifyUpdate(%p, %p, %d, %d)", ctx, buffer, offset, length);
798 
799     if (ctx == NULL || buffer == NULL) {
800         jniThrowNullPointerException(env, NULL);
801         return;
802     }
803 
804     ScopedByteArrayRO bufferBytes(env, buffer);
805     if (bufferBytes.get() == NULL) {
806         return;
807     }
808     EVP_VerifyUpdate(ctx,
809                      reinterpret_cast<const unsigned char*>(bufferBytes.get() + offset),
810                      length);
811 
812     throwExceptionIfNecessary(env, "NativeCrypto_EVP_VerifyUpdate");
813 }
814 
815 /*
816  * public static native int EVP_VerifyFinal(int, byte[], int, int, int)
817  */
NativeCrypto_EVP_VerifyFinal(JNIEnv * env,jclass,EVP_MD_CTX * ctx,jbyteArray buffer,jint offset,jint length,EVP_PKEY * pkey)818 static int NativeCrypto_EVP_VerifyFinal(JNIEnv* env, jclass, EVP_MD_CTX* ctx, jbyteArray buffer,
819                                         jint offset, jint length, EVP_PKEY* pkey) {
820     JNI_TRACE("NativeCrypto_EVP_VerifyFinal(%p, %p, %d, %d, %p)",
821               ctx, buffer, offset, length, pkey);
822 
823     if (ctx == NULL || buffer == NULL || pkey == NULL) {
824         jniThrowNullPointerException(env, NULL);
825         return -1;
826     }
827 
828     ScopedByteArrayRO bufferBytes(env, buffer);
829     if (bufferBytes.get() == NULL) {
830         return -1;
831     }
832     int result = EVP_VerifyFinal(ctx,
833                                  reinterpret_cast<const unsigned char*>(bufferBytes.get() + offset),
834                                  length,
835                                  pkey);
836 
837     throwExceptionIfNecessary(env, "NativeCrypto_EVP_VerifyFinal");
838 
839     JNI_TRACE("NativeCrypto_EVP_VerifyFinal(%p, %p, %d, %d, %p) => %d",
840               ctx, buffer, offset, length, pkey, result);
841 
842     return result;
843 }
844 
845 /**
846  * Helper function that creates an RSA public key from two buffers containing
847  * the big-endian bit representation of the modulus and the public exponent.
848  *
849  * @param mod The data of the modulus
850  * @param modLen The length of the modulus data
851  * @param exp The data of the exponent
852  * @param expLen The length of the exponent data
853  *
854  * @return A pointer to the new RSA structure, or NULL on error
855  */
rsaCreateKey(const jbyte * mod,int modLen,const jbyte * exp,int expLen)856 static RSA* rsaCreateKey(const jbyte* mod, int modLen, const jbyte* exp, int expLen) {
857     JNI_TRACE("rsaCreateKey(..., %d, ..., %d)", modLen, expLen);
858 
859     Unique_RSA rsa(RSA_new());
860     if (rsa.get() == NULL) {
861         return NULL;
862     }
863 
864     rsa->n = BN_bin2bn(reinterpret_cast<const unsigned char*>(mod), modLen, NULL);
865     rsa->e = BN_bin2bn(reinterpret_cast<const unsigned char*>(exp), expLen, NULL);
866 
867     if (rsa->n == NULL || rsa->e == NULL) {
868         return NULL;
869     }
870 
871     JNI_TRACE("rsaCreateKey(..., %d, ..., %d) => %p", modLen, expLen, rsa.get());
872     return rsa.release();
873 }
874 
875 /**
876  * Helper function that verifies a given RSA signature for a given message.
877  *
878  * @param msg The message to verify
879  * @param msgLen The length of the message
880  * @param sig The signature to verify
881  * @param sigLen The length of the signature
882  * @param algorithm The name of the hash/sign algorithm to use, e.g. "RSA-SHA1"
883  * @param rsa The RSA public key to use
884  *
885  * @return 1 on success, 0 on failure, -1 on error (check SSL errors then)
886  *
887  */
rsaVerify(const jbyte * msg,unsigned int msgLen,const jbyte * sig,unsigned int sigLen,const char * algorithm,RSA * rsa)888 static int rsaVerify(const jbyte* msg, unsigned int msgLen, const jbyte* sig,
889                      unsigned int sigLen, const char* algorithm, RSA* rsa) {
890 
891     JNI_TRACE("rsaVerify(%p, %d, %p, %d, %s, %p)",
892               msg, msgLen, sig, sigLen, algorithm, rsa);
893 
894     Unique_EVP_PKEY pkey(EVP_PKEY_new());
895     if (pkey.get() == NULL) {
896         return -1;
897     }
898     EVP_PKEY_set1_RSA(pkey.get(), rsa);
899 
900     const EVP_MD* type = EVP_get_digestbyname(algorithm);
901     if (type == NULL) {
902         return -1;
903     }
904 
905     EVP_MD_CTX ctx;
906     EVP_MD_CTX_init(&ctx);
907     if (EVP_VerifyInit_ex(&ctx, type, NULL) == 0) {
908         return -1;
909     }
910 
911     EVP_VerifyUpdate(&ctx, msg, msgLen);
912     int result = EVP_VerifyFinal(&ctx, reinterpret_cast<const unsigned char*>(sig), sigLen,
913             pkey.get());
914     EVP_MD_CTX_cleanup(&ctx);
915 
916     JNI_TRACE("rsaVerify(%p, %d, %p, %d, %s, %p) => %d",
917               msg, msgLen, sig, sigLen, algorithm, rsa, result);
918     return result;
919 }
920 
921 /**
922  * Verifies an RSA signature.
923  */
NativeCrypto_verifySignature(JNIEnv * env,jclass,jbyteArray msg,jbyteArray sig,jstring algorithm,jbyteArray mod,jbyteArray exp)924 static int NativeCrypto_verifySignature(JNIEnv* env, jclass,
925         jbyteArray msg, jbyteArray sig, jstring algorithm, jbyteArray mod, jbyteArray exp) {
926 
927     JNI_TRACE("NativeCrypto_verifySignature msg=%p sig=%p algorithm=%p mod=%p exp%p",
928               msg, sig, algorithm, mod, exp);
929 
930     ScopedByteArrayRO msgBytes(env, msg);
931     if (msgBytes.get() == NULL) {
932         return -1;
933     }
934     ScopedByteArrayRO sigBytes(env, sig);
935     if (sigBytes.get() == NULL) {
936         return -1;
937     }
938     ScopedByteArrayRO modBytes(env, mod);
939     if (modBytes.get() == NULL) {
940         return -1;
941     }
942     ScopedByteArrayRO expBytes(env, exp);
943     if (expBytes.get() == NULL) {
944         return -1;
945     }
946 
947     ScopedUtfChars algorithmChars(env, algorithm);
948     if (algorithmChars.c_str() == NULL) {
949         return -1;
950     }
951     JNI_TRACE("NativeCrypto_verifySignature algorithmChars=%s", algorithmChars.c_str());
952 
953     Unique_RSA rsa(rsaCreateKey(modBytes.get(), modBytes.size(), expBytes.get(), expBytes.size()));
954     int result = -1;
955     if (rsa.get() != NULL) {
956         result = rsaVerify(msgBytes.get(), msgBytes.size(), sigBytes.get(), sigBytes.size(),
957                 algorithmChars.c_str(), rsa.get());
958     }
959 
960     if (result == -1) {
961         if (!throwExceptionIfNecessary(env, "NativeCrypto_verifySignature")) {
962             jniThrowRuntimeException(env, "Internal error during verification");
963         }
964     }
965 
966     JNI_TRACE("NativeCrypto_verifySignature => %d", result);
967     return result;
968 }
969 
NativeCrypto_RAND_seed(JNIEnv * env,jclass,jbyteArray seed)970 static void NativeCrypto_RAND_seed(JNIEnv* env, jclass, jbyteArray seed) {
971     JNI_TRACE("NativeCrypto_RAND_seed seed=%p", seed);
972     ScopedByteArrayRO randseed(env, seed);
973     if (randseed.get() == NULL) {
974         return;
975     }
976     RAND_seed(randseed.get(), randseed.size());
977 }
978 
NativeCrypto_RAND_load_file(JNIEnv * env,jclass,jstring filename,jlong max_bytes)979 static int NativeCrypto_RAND_load_file(JNIEnv* env, jclass, jstring filename, jlong max_bytes) {
980     JNI_TRACE("NativeCrypto_RAND_load_file filename=%p max_bytes=%lld", filename, max_bytes);
981     ScopedUtfChars file(env, filename);
982     if (file.c_str() == NULL) {
983         return -1;
984     }
985     int result = RAND_load_file(file.c_str(), max_bytes);
986     JNI_TRACE("NativeCrypto_RAND_load_file file=%s => %d", file.c_str(), result);
987     return result;
988 }
989 
990 /**
991  * Convert ssl version constant to string. Based on SSL_get_version
992  */
993 // TODO move to jsse.patch
get_ssl_version(int ssl_version)994 static const char* get_ssl_version(int ssl_version) {
995     switch (ssl_version) {
996         // newest to oldest
997         case TLS1_VERSION: {
998           return SSL_TXT_TLSV1;
999         }
1000         case SSL3_VERSION: {
1001           return SSL_TXT_SSLV3;
1002         }
1003         case SSL2_VERSION: {
1004           return SSL_TXT_SSLV2;
1005         }
1006         default: {
1007           return "unknown";
1008         }
1009     }
1010 }
1011 
#ifdef WITH_JNI_TRACE
/**
 * Maps an SSL3/TLS record content type to its symbolic name for trace output.
 * Unrecognized values are logged at debug level and reported as "<unknown>".
 */
// TODO move to jsse.patch
static const char* get_content_type(int content_type) {
    const char* name;
    switch (content_type) {
        case SSL3_RT_CHANGE_CIPHER_SPEC:
            name = "SSL3_RT_CHANGE_CIPHER_SPEC";
            break;
        case SSL3_RT_ALERT:
            name = "SSL3_RT_ALERT";
            break;
        case SSL3_RT_HANDSHAKE:
            name = "SSL3_RT_HANDSHAKE";
            break;
        case SSL3_RT_APPLICATION_DATA:
            name = "SSL3_RT_APPLICATION_DATA";
            break;
        default:
            LOGD("Unknown TLS/SSL content type %d", content_type);
            name = "<unknown>";
            break;
    }
    return name;
}
#endif
1038 
#ifdef WITH_JNI_TRACE
/**
 * Simple logging callback to show handshake messages.
 *
 * Traces the direction (send/recv), protocol version, record content type,
 * buffer pointer, length and callback argument of each message.
 */
static void ssl_msg_callback_LOG(int write_p, int ssl_version, int content_type,
                                 const void* buf, size_t len, SSL* ssl, void* arg) {
    // len is a size_t; printing it with %d was a format mismatch. %lu plus an
    // explicit widening cast is portable across 32- and 64-bit targets.
    JNI_TRACE("ssl=%p SSL msg %s %s %s %p %lu %p",
              ssl,
              (write_p) ? "send" : "recv",
              get_ssl_version(ssl_version),
              get_content_type(content_type),
              buf,
              static_cast<unsigned long>(len),
              arg);
}
#endif
1055 
#ifdef WITH_JNI_TRACE
/**
 * Traces SSL state-machine events reported through an info callback.
 * Based on the example logging callback in the SSL_CTX_set_info_callback man page.
 *
 * The checks are ordered (loop, alert, exit, handshake start, handshake done)
 * and the first matching bit in where decides what is logged.
 */
static void info_callback_LOG(const SSL* s __attribute__ ((unused)), int where, int ret)
{
    const int role = where & ~SSL_ST_MASK;
    const char* side = (role & SSL_ST_CONNECT) ? "SSL_connect"
                     : (role & SSL_ST_ACCEPT) ? "SSL_accept"
                     : "undefined";

    if (where & SSL_CB_LOOP) {
        JNI_TRACE("ssl=%p %s:%s %s", s, side, SSL_state_string(s), SSL_state_string_long(s));
        return;
    }

    if (where & SSL_CB_ALERT) {
        const char* direction = (where & SSL_CB_READ) ? "read" : "write";
        JNI_TRACE("ssl=%p SSL3 alert %s:%s:%s %s %s",
                  s,
                  direction,
                  SSL_alert_type_string(ret),
                  SSL_alert_desc_string(ret),
                  SSL_alert_type_string_long(ret),
                  SSL_alert_desc_string_long(ret));
        return;
    }

    if (where & SSL_CB_EXIT) {
        // ret: 0 = failure, < 0 = error, 1 = success, anything else is unexpected.
        if (ret == 0) {
            JNI_TRACE("ssl=%p %s:failed exit in %s %s",
                      s, side, SSL_state_string(s), SSL_state_string_long(s));
        } else if (ret < 0) {
            JNI_TRACE("ssl=%p %s:error exit in %s %s",
                      s, side, SSL_state_string(s), SSL_state_string_long(s));
        } else if (ret == 1) {
            JNI_TRACE("ssl=%p %s:ok exit in %s %s",
                      s, side, SSL_state_string(s), SSL_state_string_long(s));
        } else {
            JNI_TRACE("ssl=%p %s:unknown exit %d in %s %s",
                      s, side, ret, SSL_state_string(s), SSL_state_string_long(s));
        }
        return;
    }

    if (where & SSL_CB_HANDSHAKE_START) {
        JNI_TRACE("ssl=%p handshake start in %s %s",
                  s, SSL_state_string(s), SSL_state_string_long(s));
        return;
    }

    if (where & SSL_CB_HANDSHAKE_DONE) {
        JNI_TRACE("ssl=%p handshake done in %s %s",
                  s, SSL_state_string(s), SSL_state_string_long(s));
        return;
    }

    JNI_TRACE("ssl=%p %s:unknown where %d in %s %s",
              s, side, where, SSL_state_string(s), SSL_state_string_long(s));
}
#endif
1109 
1110 /**
1111  * Returns an array containing all the X509 certificate's bytes.
1112  */
getCertificateBytes(JNIEnv * env,const STACK_OF (X509)* chain)1113 static jobjectArray getCertificateBytes(JNIEnv* env, const STACK_OF(X509)* chain)
1114 {
1115     if (chain == NULL) {
1116         // Chain can be NULL if the associated cipher doesn't do certs.
1117         return NULL;
1118     }
1119 
1120     int count = sk_X509_num(chain);
1121     if (count <= 0) {
1122         return NULL;
1123     }
1124 
1125     jobjectArray joa = env->NewObjectArray(count, JniConstants::byteArrayClass, NULL);
1126     if (joa == NULL) {
1127         return NULL;
1128     }
1129 
1130     for (int i = 0; i < count; i++) {
1131         X509* cert = sk_X509_value(chain, i);
1132 
1133         int len = i2d_X509(cert, NULL);
1134         if (len < 0) {
1135             return NULL;
1136         }
1137         ScopedLocalRef<jbyteArray> byteArray(env, env->NewByteArray(len));
1138         if (byteArray.get() == NULL) {
1139             return NULL;
1140         }
1141         ScopedByteArrayRW bytes(env, byteArray.get());
1142         if (bytes.get() == NULL) {
1143             return NULL;
1144         }
1145         unsigned char* p = reinterpret_cast<unsigned char*>(bytes.get());
1146         int n = i2d_X509(cert, &p);
1147         if (n < 0) {
1148             return NULL;
1149         }
1150         env->SetObjectArrayElement(joa, i, byteArray.get());
1151     }
1152 
1153     return joa;
1154 }
1155 
1156 /**
1157  * Returns an array containing all the X500 principal's bytes.
1158  */
getPrincipalBytes(JNIEnv * env,const STACK_OF (X509_NAME)* names)1159 static jobjectArray getPrincipalBytes(JNIEnv* env, const STACK_OF(X509_NAME)* names)
1160 {
1161     if (names == NULL) {
1162         return NULL;
1163     }
1164 
1165     int count = sk_X509_NAME_num(names);
1166     if (count <= 0) {
1167         return NULL;
1168     }
1169 
1170     jobjectArray joa = env->NewObjectArray(count, JniConstants::byteArrayClass, NULL);
1171     if (joa == NULL) {
1172         return NULL;
1173     }
1174 
1175     for (int i = 0; i < count; i++) {
1176         X509_NAME* principal = sk_X509_NAME_value(names, i);
1177 
1178         int len = i2d_X509_NAME(principal, NULL);
1179         if (len < 0) {
1180             return NULL;
1181         }
1182         ScopedLocalRef<jbyteArray> byteArray(env, env->NewByteArray(len));
1183         if (byteArray.get() == NULL) {
1184             return NULL;
1185         }
1186         ScopedByteArrayRW bytes(env, byteArray.get());
1187         if (bytes.get() == NULL) {
1188             return NULL;
1189         }
1190         unsigned char* p = reinterpret_cast<unsigned char*>(bytes.get());
1191         int n = i2d_X509_NAME(principal, &p);
1192         if (n < 0) {
1193             return NULL;
1194         }
1195         env->SetObjectArrayElement(joa, i, byteArray.get());
1196     }
1197 
1198     return joa;
1199 }
1200 
1201 /**
1202  * Our additional application data needed for getting synchronization right.
1203  * This maybe warrants a bit of lengthy prose:
1204  *
1205  * (1) We use a flag to reflect whether we consider the SSL connection alive.
1206  * Any read or write attempt loops will be cancelled once this flag becomes 0.
1207  *
1208  * (2) We use an int to count the number of threads that are blocked by the
1209  * underlying socket. This may be at most two (one reader and one writer), since
1210  * the Java layer ensures that no more threads will enter the native code at the
1211  * same time.
1212  *
1213  * (3) The pipe is used primarily as a means of cancelling a blocking select()
1214  * when we want to close the connection (aka "emergency button"). It is also
1215  * necessary for dealing with a possible race condition situation: There might
1216  * be cases where both threads see an SSL_ERROR_WANT_READ or
1217  * SSL_ERROR_WANT_WRITE. Both will enter a select() with the proper argument.
1218  * If one leaves the select() successfully before the other enters it, the
1219  * "success" event is already consumed and the second thread will be blocked,
1220  * possibly forever (depending on network conditions).
1221  *
1222  * The idea for solving the problem looks like this: Whenever a thread is
1223  * successful in moving around data on the network, and it knows there is
1224  * another thread stuck in a select(), it will write a byte to the pipe, waking
1225  * up the other thread. A thread that returned from select(), on the other hand,
1226  * knows whether it's been woken up by the pipe. If so, it will consume the
1227  * byte, and the original state of affairs has been restored.
1228  *
1229  * The pipe may seem like a bit of overhead, but it fits in nicely with the
1230  * other file descriptors of the select(), so there's only one condition to wait
1231  * for.
1232  *
1233  * (4) Finally, a mutex is needed to make sure that at most one thread is in
1234  * either SSL_read() or SSL_write() at any given time. This is an OpenSSL
1235  * requirement. We use the same mutex to guard the field for counting the
1236  * waiting threads.
1237  *
1238  * Note: The current implementation assumes that we don't have to deal with
1239  * problems induced by multiple cores or processors and their respective
1240  * memory caches. One possible problem is that of inconsistent views on the
1241  * "aliveAndKicking" field. This could be worked around by also enclosing all
1242  * accesses to that field inside a lock/unlock sequence of our mutex, but
1243  * currently this seems a bit like overkill. Marking volatile at the very least.
1244  *
1245  * During handshaking, additional fields are used to up-call into
1246  * Java to perform certificate verification and handshake
1247  * completion. These are also used in any renegotiation.
1248  *
1249  * (5) the JNIEnv so we can invoke the Java callback
1250  *
1251  * (6) a NativeCrypto.SSLHandshakeCallbacks instance for callbacks from native to Java
1252  *
1253  * (7) a java.io.FileDescriptor wrapper to check for socket close
1254  *
1255  * Because renegotiation can be requested by the peer at any time,
1256  * care should be taken to maintain an appropriate JNIEnv on any
1257  * downcall to openssl since it could result in an upcall to Java. The
1258  * current code does try to cover these cases by conditionally setting
1259  * the JNIEnv on calls that can read and write to the SSL such as
1260  * SSL_do_handshake, SSL_read, SSL_write, and SSL_shutdown.
1261  *
1262  * Finally, we have one other piece of state setup by OpenSSL callbacks:
1263  *
1264  * (8) a set of ephemeral RSA keys that is lazily generated if a peer
1265  * wants to use an exportable RSA cipher suite.
1266  *
1267  */
1268 class AppData {
1269   public:
1270     volatile int aliveAndKicking;
1271     int waitingThreads;
1272     int fdsEmergency[2];
1273     MUTEX_TYPE mutex;
1274     JNIEnv* env;
1275     jobject sslHandshakeCallbacks;
1276     jobject fileDescriptor;
1277     Unique_RSA ephemeralRsa;
1278 
1279     /**
1280      * Creates the application data context for the SSL*.
1281      */
1282   public:
create()1283     static AppData* create() {
1284         UniquePtr<AppData> appData(new AppData());
1285         if (pipe(appData.get()->fdsEmergency) == -1) {
1286             return NULL;
1287         }
1288         if (!setBlocking(appData.get()->fdsEmergency[0], false)) {
1289             return NULL;
1290         }
1291         if (MUTEX_SETUP(appData.get()->mutex) == -1) {
1292             return NULL;
1293         }
1294         return appData.release();
1295     }
1296 
~AppData()1297     ~AppData() {
1298         aliveAndKicking = 0;
1299         if (fdsEmergency[0] != -1) {
1300             close(fdsEmergency[0]);
1301         }
1302         if (fdsEmergency[1] != -1) {
1303             close(fdsEmergency[1]);
1304         }
1305         MUTEX_CLEANUP(mutex);
1306     }
1307 
1308   private:
AppData()1309     AppData() :
1310             aliveAndKicking(1),
1311             waitingThreads(0),
1312             env(NULL),
1313             sslHandshakeCallbacks(NULL),
1314             ephemeralRsa(NULL) {
1315         fdsEmergency[0] = -1;
1316         fdsEmergency[1] = -1;
1317     }
1318 
1319   public:
1320     /**
1321      * Used to set the SSL-to-Java callback state before each SSL_*
1322      * call that may result in a callback. It should be cleared after
1323      * the operation returns with clearCallbackState.
1324      *
1325      * @param env The JNIEnv
1326      * @param shc The SSLHandshakeCallbacks
1327      * @param fd The FileDescriptor
1328      */
setCallbackState(JNIEnv * e,jobject shc,jobject fd)1329     bool setCallbackState(JNIEnv* e, jobject shc, jobject fd) {
1330         NetFd netFd(e, fd);
1331         if (netFd.isClosed()) {
1332             return false;
1333         }
1334         env = e;
1335         sslHandshakeCallbacks = shc;
1336         fileDescriptor = fd;
1337         return true;
1338     }
1339 
clearCallbackState()1340     void clearCallbackState() {
1341         env = NULL;
1342         sslHandshakeCallbacks = NULL;
1343         fileDescriptor = NULL;
1344     }
1345 
1346 };
1347 
1348 /**
1349  * Dark magic helper function that checks, for a given SSL session, whether it
1350  * can SSL_read() or SSL_write() without blocking. Takes into account any
1351  * concurrent attempts to close the SSLSocket from the Java side. This is
1352  * needed to get rid of the hangs that occur when thread #1 closes the SSLSocket
1353  * while thread #2 is sitting in a blocking read or write. The type argument
1354  * specifies whether we are waiting for readability or writability. It expects
1355  * to be passed either SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, since we
1356  * only need to wait in case one of these problems occurs.
1357  *
1358  * @param env
1359  * @param type Either SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE
1360  * @param fdObject The FileDescriptor, since appData->fileDescriptor should be NULL
1361  * @param appData The application data structure with mutex info etc.
1362  * @param timeout The timeout value for select call, with the special value
1363  *                0 meaning no timeout at all (wait indefinitely). Note: This is
1364  *                the Java semantics of the timeout value, not the usual
1365  *                select() semantics.
1366  * @return The result of the inner select() call,
1367  * THROW_SOCKETEXCEPTION if a SocketException was thrown, -1 on
1368  * additional errors
1369  */
sslSelect(JNIEnv * env,int type,jobject fdObject,AppData * appData,int timeout)1370 static int sslSelect(JNIEnv* env, int type, jobject fdObject, AppData* appData, int timeout) {
1371     // This loop is an expanded version of the NET_FAILURE_RETRY
1372     // macro. It cannot simply be used in this case because select
1373     // cannot be restarted without recreating the fd_sets and timeout
1374     // structure.
1375     int result;
1376     fd_set rfds;
1377     fd_set wfds;
1378     do {
1379         NetFd fd(env, fdObject);
1380         if (fd.isClosed()) {
1381             result = THROWN_SOCKETEXCEPTION;
1382             break;
1383         }
1384         int intFd = fd.get();
1385         JNI_TRACE("sslSelect type=%s fd=%d appData=%p timeout=%d",
1386                   (type == SSL_ERROR_WANT_READ) ? "READ" : "WRITE", intFd, appData, timeout);
1387 
1388         FD_ZERO(&rfds);
1389         FD_ZERO(&wfds);
1390 
1391         if (type == SSL_ERROR_WANT_READ) {
1392             FD_SET(intFd, &rfds);
1393         } else {
1394             FD_SET(intFd, &wfds);
1395         }
1396 
1397         FD_SET(appData->fdsEmergency[0], &rfds);
1398 
1399         int maxFd = (intFd > appData->fdsEmergency[0]) ? intFd : appData->fdsEmergency[0];
1400 
1401         // Build a struct for the timeout data if we actually want a timeout.
1402         timeval tv;
1403         timeval* ptv;
1404         if (timeout > 0) {
1405             tv.tv_sec = timeout / 1000;
1406             tv.tv_usec = 0;
1407             ptv = &tv;
1408         } else {
1409             ptv = NULL;
1410         }
1411 
1412         AsynchronousSocketCloseMonitor monitor(intFd);
1413         result = select(maxFd + 1, &rfds, &wfds, NULL, ptv);
1414         JNI_TRACE("sslSelect %s fd=%d appData=%p timeout=%d => %d",
1415                   (type == SSL_ERROR_WANT_READ) ? "READ" : "WRITE",
1416                   fd.get(), appData, timeout, result);
1417         if (result == -1) {
1418             if (fd.isClosed()) {
1419                 result = THROWN_SOCKETEXCEPTION;
1420                 break;
1421             }
1422             if (errno != EINTR) {
1423                 break;
1424             }
1425         }
1426     } while (result == -1);
1427 
1428     if (MUTEX_LOCK(appData->mutex) == -1) {
1429         return -1;
1430     }
1431 
1432     if (result > 0) {
1433         // If we have been woken up by the emergency pipe. We can't be
1434         // sure there is a token in it because it could have been read
1435         // by the thread that wrote it between when when we woke up
1436         // from select and attempt to read it here. Thus we cannot
1437         // safely read it in a blocking way (so we make it
1438         // non-blocking at creation).
1439         if (FD_ISSET(appData->fdsEmergency[0], &rfds)) {
1440             char token;
1441             do {
1442                 read(appData->fdsEmergency[0], &token, 1);
1443             } while (errno == EINTR);
1444         }
1445     }
1446 
1447     // Tell the world that there is now one thread less waiting for the
1448     // underlying network.
1449     appData->waitingThreads--;
1450 
1451     MUTEX_UNLOCK(appData->mutex);
1452 
1453     return result;
1454 }
1455 
1456 /**
1457  * Helper function that wakes up a thread blocked in select(), in case there is
1458  * one. Is being called by sslRead() and sslWrite() as well as by JNI glue
1459  * before closing the connection.
1460  *
1461  * @param data The application data structure with mutex info etc.
1462  */
sslNotify(AppData * appData)1463 static void sslNotify(AppData* appData) {
1464     // Write a byte to the emergency pipe, so a concurrent select() can return.
1465     // Note we have to restore the errno of the original system call, since the
1466     // caller relies on it for generating error messages.
1467     int errnoBackup = errno;
1468     char token = '*';
1469     do {
1470         errno = 0;
1471         write(appData->fdsEmergency[1], &token, 1);
1472     } while (errno == EINTR);
1473     errno = errnoBackup;
1474 }
1475 
// From private header file external/openssl/ssl_locl.h
// TODO move dependent code to jsse.patch to avoid dependency
//
// Values of SSL_CIPHER::algorithm_auth, matched exactly by
// SSL_CIPHER_authentication_method below. They are copied because ssl_locl.h is
// not a public header, so they must be kept in sync with the bundled OpenSSL.
// NOTE(review): kept as #defines rather than an enum on purpose. An identical
// macro redefinition is legal if OpenSSL ever exports these names, whereas an
// enum with the same names would fail to compile.
#define SSL_aRSA                0x00000001L
#define SSL_aDSS                0x00000002L
#define SSL_aNULL               0x00000004L
#define SSL_aDH                 0x00000008L
#define SSL_aECDH               0x00000010L
#define SSL_aKRB5               0x00000020L
#define SSL_aECDSA              0x00000040L
#define SSL_aPSK                0x00000080L
1486 
1487 /**
1488  * Converts an SSL_CIPHER's algorithms field to a TrustManager auth argument
1489  */
1490 // TODO move to jsse.patch
SSL_CIPHER_authentication_method(const SSL_CIPHER * cipher)1491 static const char* SSL_CIPHER_authentication_method(const SSL_CIPHER* cipher)
1492 {
1493     unsigned long alg_auth = cipher->algorithm_auth;
1494 
1495     const char* au;
1496     switch (alg_auth) {
1497         case SSL_aRSA:
1498             au="RSA";
1499             break;
1500         case SSL_aDSS:
1501             au="DSS";
1502             break;
1503         case SSL_aDH:
1504             au="DH";
1505             break;
1506         case SSL_aKRB5:
1507             au="KRB5";
1508             break;
1509         case SSL_aECDH:
1510             au = "ECDH";
1511             break;
1512         case SSL_aNULL:
1513             au="None";
1514             break;
1515         case SSL_aECDSA:
1516             au="ECDSA";
1517             break;
1518         case SSL_aPSK:
1519             au="PSK";
1520             break;
1521         default:
1522             au="unknown";
1523             break;
1524     }
1525     return au;
1526 }
1527 
1528 /**
1529  * Converts an SSL_CIPHER's algorithms field to a TrustManager auth argument
1530  */
1531 // TODO move to jsse.patch
SSL_authentication_method(SSL * ssl)1532 static const char* SSL_authentication_method(SSL* ssl)
1533 {
1534     switch (ssl->version) {
1535       case SSL2_VERSION:
1536         return "RSA";
1537       case SSL3_VERSION:
1538       case TLS1_VERSION:
1539       case DTLS1_VERSION:
1540         return SSL_CIPHER_authentication_method(ssl->s3->tmp.new_cipher);
1541       default:
1542         return "unknown";
1543     }
1544 }
1545 
toAppData(const SSL * ssl)1546 static AppData* toAppData(const SSL* ssl) {
1547     return reinterpret_cast<AppData*>(SSL_get_app_data(ssl));
1548 }
1549 
1550 /**
1551  * Verify the X509 certificate via SSL_CTX_set_cert_verify_callback
1552  */
cert_verify_callback(X509_STORE_CTX * x509_store_ctx,void * arg)1553 static int cert_verify_callback(X509_STORE_CTX* x509_store_ctx, void* arg __attribute__ ((unused)))
1554 {
1555     /* Get the correct index to the SSLobject stored into X509_STORE_CTX. */
1556     SSL* ssl = reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(x509_store_ctx,
1557             SSL_get_ex_data_X509_STORE_CTX_idx()));
1558     JNI_TRACE("ssl=%p cert_verify_callback x509_store_ctx=%p arg=%p", ssl, x509_store_ctx, arg);
1559 
1560     AppData* appData = toAppData(ssl);
1561     JNIEnv* env = appData->env;
1562     if (env == NULL) {
1563         LOGE("AppData->env missing in cert_verify_callback");
1564         JNI_TRACE("ssl=%p cert_verify_callback => 0", ssl);
1565         return 0;
1566     }
1567     jobject sslHandshakeCallbacks = appData->sslHandshakeCallbacks;
1568 
1569     jclass cls = env->GetObjectClass(sslHandshakeCallbacks);
1570     jmethodID methodID
1571         = env->GetMethodID(cls, "verifyCertificateChain", "([[BLjava/lang/String;)V");
1572 
1573     jobjectArray objectArray = getCertificateBytes(env, x509_store_ctx->untrusted);
1574 
1575     const char* authMethod = SSL_authentication_method(ssl);
1576     JNI_TRACE("ssl=%p cert_verify_callback calling verifyCertificateChain authMethod=%s",
1577               ssl, authMethod);
1578     jstring authMethodString = env->NewStringUTF(authMethod);
1579     env->CallVoidMethod(sslHandshakeCallbacks, methodID, objectArray, authMethodString);
1580 
1581     int result = (env->ExceptionCheck()) ? 0 : 1;
1582     JNI_TRACE("ssl=%p cert_verify_callback => %d", ssl, result);
1583     return result;
1584 }
1585 
1586 /**
1587  * Call back to watch for handshake to be completed. This is necessary
1588  * for SSL_MODE_HANDSHAKE_CUTTHROUGH support, since SSL_do_handshake
1589  * returns before the handshake is completed in this case.
1590  */
info_callback(const SSL * ssl,int where,int ret)1591 static void info_callback(const SSL* ssl, int where, int ret __attribute__ ((unused))) {
1592     JNI_TRACE("ssl=%p info_callback where=0x%x ret=%d", ssl, where, ret);
1593 #ifdef WITH_JNI_TRACE
1594     info_callback_LOG(ssl, where, ret);
1595 #endif
1596     if (!(where & SSL_CB_HANDSHAKE_DONE)) {
1597         JNI_TRACE("ssl=%p info_callback ignored", ssl);
1598         return;
1599     }
1600 
1601     AppData* appData = toAppData(ssl);
1602     JNIEnv* env = appData->env;
1603     if (env == NULL) {
1604         LOGE("AppData->env missing in info_callback");
1605         JNI_TRACE("ssl=%p info_callback env error", ssl);
1606         return;
1607     }
1608     if (env->ExceptionCheck()) {
1609         JNI_TRACE("ssl=%p info_callback already pending exception", ssl);
1610         return;
1611     }
1612 
1613     jobject sslHandshakeCallbacks = appData->sslHandshakeCallbacks;
1614 
1615     jclass cls = env->GetObjectClass(sslHandshakeCallbacks);
1616     jmethodID methodID = env->GetMethodID(cls, "handshakeCompleted", "()V");
1617 
1618     JNI_TRACE("ssl=%p info_callback calling handshakeCompleted", ssl);
1619     env->CallVoidMethod(sslHandshakeCallbacks, methodID);
1620 
1621     if (env->ExceptionCheck()) {
1622         JNI_TRACE("ssl=%p info_callback exception", ssl);
1623     }
1624     JNI_TRACE("ssl=%p info_callback completed", ssl);
1625 }
1626 
1627 /**
1628  * Call back to ask for a client certificate
1629  */
client_cert_cb(SSL * ssl,X509 ** x509Out,EVP_PKEY ** pkeyOut)1630 static int client_cert_cb(SSL* ssl, X509** x509Out, EVP_PKEY** pkeyOut) {
1631     JNI_TRACE("ssl=%p client_cert_cb x509Out=%p pkeyOut=%p", ssl, x509Out, pkeyOut);
1632 
1633     AppData* appData = toAppData(ssl);
1634     JNIEnv* env = appData->env;
1635     if (env == NULL) {
1636         LOGE("AppData->env missing in client_cert_cb");
1637         JNI_TRACE("ssl=%p client_cert_cb env error => 0", ssl);
1638         return 0;
1639     }
1640     if (env->ExceptionCheck()) {
1641         JNI_TRACE("ssl=%p client_cert_cb already pending exception", ssl);
1642         return 0;
1643     }
1644     jobject sslHandshakeCallbacks = appData->sslHandshakeCallbacks;
1645 
1646     jclass cls = env->GetObjectClass(sslHandshakeCallbacks);
1647     jmethodID methodID
1648         = env->GetMethodID(cls, "clientCertificateRequested", "([B[[B)V");
1649 
1650     // Call Java callback which can use SSL_use_certificate and SSL_use_PrivateKey to set values
1651     char ssl2_ctype = SSL3_CT_RSA_SIGN;
1652     const char* ctype = NULL;
1653     int ctype_num = 0;
1654     jobjectArray issuers = NULL;
1655     switch (ssl->version) {
1656         case SSL2_VERSION:
1657             ctype = &ssl2_ctype;
1658             ctype_num = 1;
1659             break;
1660         case SSL3_VERSION:
1661         case TLS1_VERSION:
1662         case DTLS1_VERSION:
1663             ctype = ssl->s3->tmp.ctype;
1664             ctype_num = ssl->s3->tmp.ctype_num;
1665             issuers = getPrincipalBytes(env, ssl->s3->tmp.ca_names);
1666             break;
1667     }
1668 #ifdef WITH_JNI_TRACE
1669     for (int i = 0; i < ctype_num; i++) {
1670         JNI_TRACE("ssl=%p clientCertificateRequested keyTypes[%d]=%d", ssl, i, ctype[i]);
1671     }
1672 #endif
1673 
1674     jbyteArray keyTypes = env->NewByteArray(ctype_num);
1675     if (keyTypes == NULL) {
1676         JNI_TRACE("ssl=%p client_cert_cb bytes == null => 0", ssl);
1677         return 0;
1678     }
1679     env->SetByteArrayRegion(keyTypes, 0, ctype_num, reinterpret_cast<const jbyte*>(ctype));
1680 
1681     JNI_TRACE("ssl=%p clientCertificateRequested calling clientCertificateRequested "
1682               "keyTypes=%p issuers=%p", ssl, keyTypes, issuers);
1683     env->CallVoidMethod(sslHandshakeCallbacks, methodID, keyTypes, issuers);
1684 
1685     if (env->ExceptionCheck()) {
1686         JNI_TRACE("ssl=%p client_cert_cb exception => 0", ssl);
1687         return 0;
1688     }
1689 
1690     // Check for values set from Java
1691     X509*     certificate = SSL_get_certificate(ssl);
1692     EVP_PKEY* privatekey  = SSL_get_privatekey(ssl);
1693     int result;
1694     if (certificate != NULL && privatekey != NULL) {
1695         *x509Out = certificate;
1696         *pkeyOut = privatekey;
1697         result = 1;
1698     } else {
1699         *x509Out = NULL;
1700         *pkeyOut = NULL;
1701         result = 0;
1702     }
1703     JNI_TRACE("ssl=%p client_cert_cb => *x509=%p *pkey=%p %d", ssl, *x509Out, *pkeyOut, result);
1704     return result;
1705 }
1706 
rsaGenerateKey(int keylength)1707 static RSA* rsaGenerateKey(int keylength) {
1708     Unique_BIGNUM bn(BN_new());
1709     if (bn.get() == NULL) {
1710         return NULL;
1711     }
1712     int setWordResult = BN_set_word(bn.get(), RSA_F4);
1713     if (setWordResult != 1) {
1714         return NULL;
1715     }
1716     Unique_RSA rsa(RSA_new());
1717     if (rsa.get() == NULL) {
1718         return NULL;
1719     }
1720     int generateResult = RSA_generate_key_ex(rsa.get(), keylength, bn.get(), NULL);
1721     if (generateResult != 1) {
1722         return NULL;
1723     }
1724     return rsa.release();
1725 }
1726 
1727 /**
1728  * Call back to ask for an ephemeral RSA key for SSL_RSA_EXPORT_WITH_RC4_40_MD5 (aka EXP-RC4-MD5)
1729  */
tmp_rsa_callback(SSL * ssl,int is_export,int keylength)1730 static RSA* tmp_rsa_callback(SSL* ssl __attribute__ ((unused)),
1731                              int is_export __attribute__ ((unused)),
1732                              int keylength) {
1733     JNI_TRACE("ssl=%p tmp_rsa_callback is_export=%d keylength=%d", ssl, is_export, keylength);
1734 
1735     AppData* appData = toAppData(ssl);
1736     if (appData->ephemeralRsa.get() == NULL) {
1737         JNI_TRACE("ssl=%p tmp_rsa_callback generating ephemeral RSA key", ssl);
1738         appData->ephemeralRsa.reset(rsaGenerateKey(keylength));
1739     }
1740     JNI_TRACE("ssl=%p tmp_rsa_callback => %p", ssl, appData->ephemeralRsa.get());
1741     return appData->ephemeralRsa.get();
1742 }
1743 
dhGenerateParameters(int keylength)1744 static DH* dhGenerateParameters(int keylength) {
1745 
1746     /*
1747      * The SSL_CTX_set_tmp_dh_callback(3SSL) man page discusses two
1748      * different options for generating DH keys. One is generating the
1749      * keys using a single set of DH parameters. However, generating
1750      * DH parameters is slow enough (minutes) that they suggest doing
1751      * it once at install time. The other is to generate DH keys from
1752      * DSA parameters. Generating DSA parameters is faster than DH
1753      * parameters, but to prevent small subgroup attacks, they needed
1754      * to be regenerated for each set of DH keys. Setting the
1755      * SSL_OP_SINGLE_DH_USE option make sure OpenSSL will call back
1756      * for new DH parameters every type it needs to generate DH keys.
1757      */
1758 #if 0
1759     // Slow path that takes minutes but could be cached
1760     Unique_DH dh(DH_new());
1761     if (!DH_generate_parameters_ex(dh.get(), keylength, 2, NULL)) {
1762         return NULL;
1763     }
1764     return dh.release();
1765 #else
1766     // Faster path but must have SSL_OP_SINGLE_DH_USE set
1767     Unique_DSA dsa(DSA_new());
1768     if (!DSA_generate_parameters_ex(dsa.get(), keylength, NULL, 0, NULL, NULL, NULL)) {
1769         return NULL;
1770     }
1771     DH* dh = DSA_dup_DH(dsa.get());
1772     return dh;
1773 #endif
1774 }
1775 
1776 /**
1777  * Call back to ask for Diffie-Hellman parameters
1778  */
tmp_dh_callback(SSL * ssl,int is_export,int keylength)1779 static DH* tmp_dh_callback(SSL* ssl __attribute__ ((unused)),
1780                            int is_export __attribute__ ((unused)),
1781                            int keylength) {
1782     JNI_TRACE("ssl=%p tmp_dh_callback is_export=%d keylength=%d", ssl, is_export, keylength);
1783     DH* tmp_dh = dhGenerateParameters(keylength);
1784     JNI_TRACE("ssl=%p tmp_dh_callback => %p", ssl, tmp_dh);
1785     return tmp_dh;
1786 }
1787 
1788 /*
1789  * public static native int SSL_CTX_new();
1790  */
NativeCrypto_SSL_CTX_new(JNIEnv * env,jclass)1791 static int NativeCrypto_SSL_CTX_new(JNIEnv* env, jclass) {
1792     Unique_SSL_CTX sslCtx(SSL_CTX_new(SSLv23_method()));
1793     if (sslCtx.get() == NULL) {
1794         jniThrowRuntimeException(env, "SSL_CTX_new");
1795         return NULL;
1796     }
1797     SSL_CTX_set_options(sslCtx.get(),
1798                         SSL_OP_ALL
1799                         // Note: We explicitly do not allow SSLv2 to be used.
1800                         | SSL_OP_NO_SSLv2
1801                         // We also disable session tickets for better compatibility b/2682876
1802                         | SSL_OP_NO_TICKET
1803                         // We also disable compression for better compatibility b/2710492 b/2710497
1804                         | SSL_OP_NO_COMPRESSION
1805                         // Because dhGenerateParameters uses DSA_generate_parameters_ex
1806                         | SSL_OP_SINGLE_DH_USE);
1807 
1808     int mode = SSL_CTX_get_mode(sslCtx.get());
1809     /*
1810      * Turn on "partial write" mode. This means that SSL_write() will
1811      * behave like Posix write() and possibly return after only
1812      * writing a partial buffer. Note: The alternative, perhaps
1813      * surprisingly, is not that SSL_write() always does full writes
1814      * but that it will force you to retry write calls having
1815      * preserved the full state of the original call. (This is icky
1816      * and undesirable.)
1817      */
1818     mode |= SSL_MODE_ENABLE_PARTIAL_WRITE;
1819 #if defined(SSL_MODE_SMALL_BUFFERS) /* not all SSL versions have this */
1820     mode |= SSL_MODE_SMALL_BUFFERS;  /* lazily allocate record buffers; usually saves
1821                                       * 44k over the default */
1822 #endif
1823 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) /* not all SSL versions have this */
1824     mode |= SSL_MODE_HANDSHAKE_CUTTHROUGH;  /* enable sending of client data as soon as
1825                                              * ClientCCS and ClientFinished are sent */
1826 #endif
1827     SSL_CTX_set_mode(sslCtx.get(), mode);
1828 
1829     SSL_CTX_set_cert_verify_callback(sslCtx.get(), cert_verify_callback, NULL);
1830     SSL_CTX_set_info_callback(sslCtx.get(), info_callback);
1831     SSL_CTX_set_client_cert_cb(sslCtx.get(), client_cert_cb);
1832     SSL_CTX_set_tmp_rsa_callback(sslCtx.get(), tmp_rsa_callback);
1833     SSL_CTX_set_tmp_dh_callback(sslCtx.get(), tmp_dh_callback);
1834 
1835 #ifdef WITH_JNI_TRACE
1836     SSL_CTX_set_msg_callback(sslCtx.get(), ssl_msg_callback_LOG); /* enable for message debug */
1837 #endif
1838     JNI_TRACE("NativeCrypto_SSL_CTX_new => %p", sslCtx.get());
1839     return (jint) sslCtx.release();
1840 }
1841 
1842 /**
1843  * public static native void SSL_CTX_free(int ssl_ctx)
1844  */
NativeCrypto_SSL_CTX_free(JNIEnv * env,jclass,jint ssl_ctx_address)1845 static void NativeCrypto_SSL_CTX_free(JNIEnv* env,
1846         jclass, jint ssl_ctx_address)
1847 {
1848     SSL_CTX* ssl_ctx = to_SSL_CTX(env, ssl_ctx_address, true);
1849     JNI_TRACE("ssl_ctx=%p NativeCrypto_SSL_CTX_free", ssl_ctx);
1850     if (ssl_ctx == NULL) {
1851         return;
1852     }
1853     SSL_CTX_free(ssl_ctx);
1854 }
1855 
1856 /**
1857  * public static native int SSL_new(int ssl_ctx) throws SSLException;
1858  */
NativeCrypto_SSL_new(JNIEnv * env,jclass,jint ssl_ctx_address)1859 static jint NativeCrypto_SSL_new(JNIEnv* env, jclass, jint ssl_ctx_address)
1860 {
1861     SSL_CTX* ssl_ctx = to_SSL_CTX(env, ssl_ctx_address, true);
1862     JNI_TRACE("ssl_ctx=%p NativeCrypto_SSL_new", ssl_ctx);
1863     if (ssl_ctx == NULL) {
1864         return NULL;
1865     }
1866     Unique_SSL ssl(SSL_new(ssl_ctx));
1867     if (ssl.get() == NULL) {
1868         throwSSLExceptionWithSslErrors(env, NULL, SSL_ERROR_NONE,
1869                 "Unable to create SSL structure");
1870         JNI_TRACE("ssl_ctx=%p NativeCrypto_SSL_new => NULL", ssl_ctx);
1871         return NULL;
1872     }
1873 
1874     /* Java code in class OpenSSLSocketImpl does the verification. Meaning of
1875      * SSL_VERIFY_NONE flag in client mode: if not using an anonymous cipher
1876      * (by default disabled), the server will send a certificate which will
1877      * be checked. The result of the certificate verification process can be
1878      * checked after the TLS/SSL handshake using the SSL_get_verify_result(3)
1879      * function. The handshake will be continued regardless of the
1880      * verification result.
1881      */
1882     SSL_set_verify(ssl.get(), SSL_VERIFY_NONE, NULL);
1883 
1884     JNI_TRACE("ssl_ctx=%p NativeCrypto_SSL_new => ssl=%p", ssl_ctx, ssl.get());
1885     return (jint) ssl.release();
1886 }
1887 
NativeCrypto_SSL_use_PrivateKey(JNIEnv * env,jclass,jint ssl_address,jbyteArray privatekey)1888 static void NativeCrypto_SSL_use_PrivateKey(JNIEnv* env, jclass,
1889                                             jint ssl_address, jbyteArray privatekey)
1890 {
1891     SSL* ssl = to_SSL(env, ssl_address, true);
1892     JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey privatekey=%p", ssl, privatekey);
1893     if (ssl == NULL) {
1894         return;
1895     }
1896 
1897     ScopedByteArrayRO buf(env, privatekey);
1898     if (buf.get() == NULL) {
1899         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => threw exception", ssl);
1900         return;
1901     }
1902     const unsigned char* tmp = reinterpret_cast<const unsigned char*>(buf.get());
1903     Unique_PKCS8_PRIV_KEY_INFO pkcs8(d2i_PKCS8_PRIV_KEY_INFO(NULL, &tmp, buf.size()));
1904     if (pkcs8.get() == NULL) {
1905         LOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
1906         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE,
1907                                        "Error parsing private key from DER to PKCS8");
1908         SSL_clear(ssl);
1909         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => error from DER to PKCS8", ssl);
1910         return;
1911     }
1912 
1913     Unique_EVP_PKEY privatekeyevp(EVP_PKCS82PKEY(pkcs8.get()));
1914     if (privatekeyevp.get() == NULL) {
1915         LOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
1916         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE,
1917                                        "Error creating private key from PKCS8");
1918         SSL_clear(ssl);
1919         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => error from PKCS8 to key", ssl);
1920         return;
1921     }
1922 
1923     int ret = SSL_use_PrivateKey(ssl, privatekeyevp.get());
1924     if (ret == 1) {
1925         privatekeyevp.release();
1926     } else {
1927         LOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
1928         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error setting private key");
1929         SSL_clear(ssl);
1930         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => error", ssl);
1931         return;
1932     }
1933 
1934     JNI_TRACE("ssl=%p NativeCrypto_SSL_use_PrivateKey => ok", ssl);
1935 }
1936 
NativeCrypto_SSL_use_certificate(JNIEnv * env,jclass,jint ssl_address,jobjectArray certificates)1937 static void NativeCrypto_SSL_use_certificate(JNIEnv* env, jclass,
1938                                              jint ssl_address, jobjectArray certificates)
1939 {
1940     SSL* ssl = to_SSL(env, ssl_address, true);
1941     JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate certificates=%p", ssl, certificates);
1942     if (ssl == NULL) {
1943         return;
1944     }
1945 
1946     if (certificates == NULL) {
1947         jniThrowNullPointerException(env, "certificates == null");
1948         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => certificates == null", ssl);
1949         return;
1950     }
1951 
1952     int length = env->GetArrayLength(certificates);
1953     if (length == 0) {
1954         jniThrowException(env, "java/lang/IllegalArgumentException", "certificates.length == 0");
1955         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => certificates.length == 0", ssl);
1956         return;
1957     }
1958 
1959     Unique_X509 certificatesX509[length];
1960     for (int i = 0; i < length; i++) {
1961         ScopedLocalRef<jbyteArray> certificate(env,
1962                 reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(certificates, i)));
1963         if (certificate.get() == NULL) {
1964             jniThrowNullPointerException(env, "certificates element == null");
1965             JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => certificates element null", ssl);
1966             return;
1967         }
1968 
1969         ScopedByteArrayRO buf(env, certificate.get());
1970         if (buf.get() == NULL) {
1971             JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => threw exception", ssl);
1972             return;
1973         }
1974         const unsigned char* tmp = reinterpret_cast<const unsigned char*>(buf.get());
1975         certificatesX509[i].reset(d2i_X509(NULL, &tmp, buf.size()));
1976 
1977         if (certificatesX509[i].get() == NULL) {
1978             LOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
1979             throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error parsing certificate");
1980             SSL_clear(ssl);
1981             JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => certificates parsing error", ssl);
1982             return;
1983         }
1984     }
1985 
1986     int ret = SSL_use_certificate(ssl, certificatesX509[0].get());
1987     if (ret == 1) {
1988         certificatesX509[0].release();
1989     } else {
1990         LOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
1991         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error setting certificate");
1992         SSL_clear(ssl);
1993         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => SSL_use_certificate error", ssl);
1994         return;
1995     }
1996 
1997     Unique_sk_X509 chain(sk_X509_new_null());
1998     if (chain.get() == NULL) {
1999         jniThrowOutOfMemoryError(env, "Unable to allocate local certificate chain");
2000         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => chain allocation error", ssl);
2001         return;
2002     }
2003     for (int i = 1; i < length; i++) {
2004         if (!sk_X509_push(chain.get(), certificatesX509[i].release())) {
2005             jniThrowOutOfMemoryError(env, "Unable to push certificate");
2006             JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => certificate push error", ssl);
2007             return;
2008         }
2009     }
2010     int chainResult = SSL_use_certificate_chain(ssl, chain.get());
2011     if (chainResult == 0) {
2012         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error setting certificate chain");
2013         JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => SSL_use_certificate_chain error",
2014                   ssl);
2015         return;
2016     } else {
2017         chain.release();
2018     }
2019 
2020     JNI_TRACE("ssl=%p NativeCrypto_SSL_use_certificate => ok", ssl);
2021 }
2022 
NativeCrypto_SSL_check_private_key(JNIEnv * env,jclass,jint ssl_address)2023 static void NativeCrypto_SSL_check_private_key(JNIEnv* env, jclass, jint ssl_address)
2024 {
2025     SSL* ssl = to_SSL(env, ssl_address, true);
2026     JNI_TRACE("ssl=%p NativeCrypto_SSL_check_private_key", ssl);
2027     if (ssl == NULL) {
2028         return;
2029     }
2030     int ret = SSL_check_private_key(ssl);
2031     if (ret != 1) {
2032         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error checking private key");
2033         SSL_clear(ssl);
2034         JNI_TRACE("ssl=%p NativeCrypto_SSL_check_private_key => error", ssl);
2035         return;
2036     }
2037     JNI_TRACE("ssl=%p NativeCrypto_SSL_check_private_key => ok", ssl);
2038 }
2039 
NativeCrypto_SSL_set_client_CA_list(JNIEnv * env,jclass,jint ssl_address,jobjectArray principals)2040 static void NativeCrypto_SSL_set_client_CA_list(JNIEnv* env, jclass,
2041                                                 jint ssl_address, jobjectArray principals)
2042 {
2043     SSL* ssl = to_SSL(env, ssl_address, true);
2044     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list principals=%p", ssl, principals);
2045     if (ssl == NULL) {
2046         return;
2047     }
2048 
2049     if (principals == NULL) {
2050         jniThrowNullPointerException(env, "principals == null");
2051         JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list => principals == null", ssl);
2052         return;
2053     }
2054 
2055     int length = env->GetArrayLength(principals);
2056     if (length == 0) {
2057         jniThrowException(env, "java/lang/IllegalArgumentException", "principals.length == 0");
2058         JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list => principals.length == 0", ssl);
2059         return;
2060     }
2061 
2062     Unique_sk_X509_NAME principalsStack(sk_X509_NAME_new_null());
2063     if (principalsStack.get() == NULL) {
2064         jniThrowOutOfMemoryError(env, "Unable to allocate principal stack");
2065         JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list => stack allocation error", ssl);
2066         return;
2067     }
2068     for (int i = 0; i < length; i++) {
2069         ScopedLocalRef<jbyteArray> principal(env,
2070                 reinterpret_cast<jbyteArray>(env->GetObjectArrayElement(principals, i)));
2071         if (principal.get() == NULL) {
2072             jniThrowNullPointerException(env, "principals element == null");
2073             JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list => principals element null", ssl);
2074             return;
2075         }
2076 
2077         ScopedByteArrayRO buf(env, principal.get());
2078         if (buf.get() == NULL) {
2079             JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list => threw exception", ssl);
2080             return;
2081         }
2082         const unsigned char* tmp = reinterpret_cast<const unsigned char*>(buf.get());
2083         Unique_X509_NAME principalX509Name(d2i_X509_NAME(NULL, &tmp, buf.size()));
2084 
2085         if (principalX509Name.get() == NULL) {
2086             LOGE("%s", ERR_error_string(ERR_peek_error(), NULL));
2087             throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error parsing principal");
2088             SSL_clear(ssl);
2089             JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list => principals parsing error",
2090                       ssl);
2091             return;
2092         }
2093 
2094         if (!sk_X509_NAME_push(principalsStack.get(), principalX509Name.release())) {
2095             jniThrowOutOfMemoryError(env, "Unable to push principal");
2096             JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list => principal push error", ssl);
2097             return;
2098         }
2099     }
2100 
2101     SSL_set_client_CA_list(ssl, principalsStack.release());
2102     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_client_CA_list => ok", ssl);
2103 }
2104 
2105 /**
2106  * public static native long SSL_get_mode(int ssl);
2107  */
NativeCrypto_SSL_get_mode(JNIEnv * env,jclass,jint ssl_address)2108 static jlong NativeCrypto_SSL_get_mode(JNIEnv* env, jclass, jint ssl_address) {
2109     SSL* ssl = to_SSL(env, ssl_address, true);
2110     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_mode", ssl);
2111     if (ssl == NULL) {
2112       return 0;
2113     }
2114     long mode = SSL_get_mode(ssl);
2115     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_mode => 0x%lx", ssl, mode);
2116     return mode;
2117 }
2118 
2119 /**
2120  * public static native long SSL_set_mode(int ssl, long mode);
2121  */
NativeCrypto_SSL_set_mode(JNIEnv * env,jclass,jint ssl_address,jlong mode)2122 static jlong NativeCrypto_SSL_set_mode(JNIEnv* env, jclass,
2123         jint ssl_address, jlong mode) {
2124     SSL* ssl = to_SSL(env, ssl_address, true);
2125     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_mode mode=0x%llx", ssl, mode);
2126     if (ssl == NULL) {
2127       return 0;
2128     }
2129     long result = SSL_set_mode(ssl, mode);
2130     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_mode => 0x%lx", ssl, result);
2131     return result;
2132 }
2133 
2134 /**
2135  * public static native long SSL_clear_mode(int ssl, long mode);
2136  */
NativeCrypto_SSL_clear_mode(JNIEnv * env,jclass,jint ssl_address,jlong mode)2137 static jlong NativeCrypto_SSL_clear_mode(JNIEnv* env, jclass,
2138         jint ssl_address, jlong mode) {
2139     SSL* ssl = to_SSL(env, ssl_address, true);
2140     JNI_TRACE("ssl=%p NativeCrypto_SSL_clear_mode mode=0x%llx", ssl, mode);
2141     if (ssl == NULL) {
2142       return 0;
2143     }
2144     long result = SSL_clear_mode(ssl, mode);
2145     JNI_TRACE("ssl=%p NativeCrypto_SSL_clear_mode => 0x%lx", ssl, result);
2146     return result;
2147 }
2148 
2149 /**
2150  * public static native long SSL_get_options(int ssl);
2151  */
NativeCrypto_SSL_get_options(JNIEnv * env,jclass,jint ssl_address)2152 static jlong NativeCrypto_SSL_get_options(JNIEnv* env, jclass,
2153         jint ssl_address) {
2154     SSL* ssl = to_SSL(env, ssl_address, true);
2155     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_options", ssl);
2156     if (ssl == NULL) {
2157       return 0;
2158     }
2159     long options = SSL_get_options(ssl);
2160     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_options => 0x%lx", ssl, options);
2161     return options;
2162 }
2163 
2164 /**
2165  * public static native long SSL_set_options(int ssl, long options);
2166  */
NativeCrypto_SSL_set_options(JNIEnv * env,jclass,jint ssl_address,jlong options)2167 static jlong NativeCrypto_SSL_set_options(JNIEnv* env, jclass,
2168         jint ssl_address, jlong options) {
2169     SSL* ssl = to_SSL(env, ssl_address, true);
2170     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_options options=0x%llx", ssl, options);
2171     if (ssl == NULL) {
2172       return 0;
2173     }
2174     long result = SSL_set_options(ssl, options);
2175     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_options => 0x%lx", ssl, result);
2176     return result;
2177 }
2178 
2179 /**
2180  * public static native long SSL_clear_options(int ssl, long options);
2181  */
NativeCrypto_SSL_clear_options(JNIEnv * env,jclass,jint ssl_address,jlong options)2182 static jlong NativeCrypto_SSL_clear_options(JNIEnv* env, jclass,
2183         jint ssl_address, jlong options) {
2184     SSL* ssl = to_SSL(env, ssl_address, true);
2185     JNI_TRACE("ssl=%p NativeCrypto_SSL_clear_options options=0x%llx", ssl, options);
2186     if (ssl == NULL) {
2187       return 0;
2188     }
2189     long result = SSL_clear_options(ssl, options);
2190     JNI_TRACE("ssl=%p NativeCrypto_SSL_clear_options => 0x%lx", ssl, result);
2191     return result;
2192 }
2193 
2194 /**
2195  * Sets the ciphers suites that are enabled in the SSL
2196  */
NativeCrypto_SSL_set_cipher_lists(JNIEnv * env,jclass,jint ssl_address,jobjectArray cipherSuites)2197 static void NativeCrypto_SSL_set_cipher_lists(JNIEnv* env, jclass,
2198         jint ssl_address, jobjectArray cipherSuites)
2199 {
2200     SSL* ssl = to_SSL(env, ssl_address, true);
2201     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_cipher_lists cipherSuites=%p", ssl, cipherSuites);
2202     if (ssl == NULL) {
2203         return;
2204     }
2205     if (cipherSuites == NULL) {
2206         jniThrowNullPointerException(env, "cipherSuites == null");
2207         return;
2208     }
2209 
2210     Unique_sk_SSL_CIPHER cipherstack(sk_SSL_CIPHER_new_null());
2211     if (cipherstack.get() == NULL) {
2212         jniThrowRuntimeException(env, "sk_SSL_CIPHER_new_null failed");
2213         return;
2214     }
2215 
2216     const SSL_METHOD* ssl_method = ssl->method;
2217     int num_ciphers = ssl_method->num_ciphers();
2218 
2219     int length = env->GetArrayLength(cipherSuites);
2220     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_cipher_lists length=%d", ssl, length);
2221     for (int i = 0; i < length; i++) {
2222         ScopedLocalRef<jstring> cipherSuite(env,
2223                 reinterpret_cast<jstring>(env->GetObjectArrayElement(cipherSuites, i)));
2224         ScopedUtfChars c(env, cipherSuite.get());
2225         if (c.c_str() == NULL) {
2226             return;
2227         }
2228         JNI_TRACE("ssl=%p NativeCrypto_SSL_set_cipher_lists cipherSuite=%s", ssl, c.c_str());
2229         bool found = false;
2230         for (int j = 0; j < num_ciphers; j++) {
2231             const SSL_CIPHER* cipher = ssl_method->get_cipher(j);
2232             if ((strcmp(c.c_str(), cipher->name) == 0)
2233                     && (strcmp(SSL_CIPHER_get_version(cipher), "SSLv2"))) {
2234                 if (!sk_SSL_CIPHER_push(cipherstack.get(), cipher)) {
2235                     jniThrowOutOfMemoryError(env, "Unable to push cipher");
2236                     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_cipher_lists => cipher push error", ssl);
2237                     return;
2238                 }
2239                 found = true;
2240             }
2241         }
2242         if (!found) {
2243             jniThrowException(env, "java/lang/IllegalArgumentException",
2244                               "Could not find cipher suite.");
2245             return;
2246         }
2247     }
2248 
2249     int rc = SSL_set_cipher_lists(ssl, cipherstack.get());
2250     if (rc == 0) {
2251         freeSslErrorState();
2252         jniThrowException(env, "java/lang/IllegalArgumentException",
2253                           "Illegal cipher suite strings.");
2254     } else {
2255         cipherstack.release();
2256     }
2257 }
2258 
2259 /**
2260  * Sets certificate expectations, especially for server to request client auth
2261  */
NativeCrypto_SSL_set_verify(JNIEnv * env,jclass,jint ssl_address,jint mode)2262 static void NativeCrypto_SSL_set_verify(JNIEnv* env,
2263         jclass, jint ssl_address, jint mode)
2264 {
2265     SSL* ssl = to_SSL(env, ssl_address, true);
2266     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_verify mode=%x", ssl, mode);
2267     if (ssl == NULL) {
2268       return;
2269     }
2270     SSL_set_verify(ssl, (int)mode, NULL);
2271 }
2272 
2273 /**
2274  * Sets the ciphers suites that are enabled in the SSL
2275  */
NativeCrypto_SSL_set_session(JNIEnv * env,jclass,jint ssl_address,jint ssl_session_address)2276 static void NativeCrypto_SSL_set_session(JNIEnv* env, jclass,
2277         jint ssl_address, jint ssl_session_address)
2278 {
2279     SSL* ssl = to_SSL(env, ssl_address, true);
2280     SSL_SESSION* ssl_session = to_SSL_SESSION(env, ssl_session_address, false);
2281     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_session ssl_session=%p", ssl, ssl_session);
2282     if (ssl == NULL) {
2283         return;
2284     }
2285 
2286     int ret = SSL_set_session(ssl, ssl_session);
2287     if (ret != 1) {
2288         /*
2289          * Translate the error, and throw if it turns out to be a real
2290          * problem.
2291          */
2292         int sslErrorCode = SSL_get_error(ssl, ret);
2293         if (sslErrorCode != SSL_ERROR_ZERO_RETURN) {
2294             throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "SSL session set");
2295             SSL_clear(ssl);
2296         }
2297     }
2298 }
2299 
2300 /**
2301  * Sets the ciphers suites that are enabled in the SSL
2302  */
NativeCrypto_SSL_set_session_creation_enabled(JNIEnv * env,jclass,jint ssl_address,jboolean creation_enabled)2303 static void NativeCrypto_SSL_set_session_creation_enabled(JNIEnv* env, jclass,
2304         jint ssl_address, jboolean creation_enabled)
2305 {
2306     SSL* ssl = to_SSL(env, ssl_address, true);
2307     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_session_creation_enabled creation_enabled=%d",
2308               ssl, creation_enabled);
2309     if (ssl == NULL) {
2310         return;
2311     }
2312     SSL_set_session_creation_enabled(ssl, creation_enabled);
2313 }
2314 
NativeCrypto_SSL_set_tlsext_host_name(JNIEnv * env,jclass,jint ssl_address,jstring hostname)2315 static void NativeCrypto_SSL_set_tlsext_host_name(JNIEnv* env, jclass,
2316         jint ssl_address, jstring hostname)
2317 {
2318     SSL* ssl = to_SSL(env, ssl_address, true);
2319     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_tlsext_host_name hostname=%p",
2320               ssl, hostname);
2321     if (ssl == NULL) {
2322         return;
2323     }
2324 
2325     ScopedUtfChars hostnameChars(env, hostname);
2326     if (hostnameChars.c_str() == NULL) {
2327         return;
2328     }
2329     JNI_TRACE("NativeCrypto_SSL_set_tlsext_host_name hostnameChars=%s", hostnameChars.c_str());
2330 
2331     int ret = SSL_set_tlsext_host_name(ssl, hostnameChars.c_str());
2332     if (ret != 1) {
2333         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE, "Error setting host name");
2334         SSL_clear(ssl);
2335         JNI_TRACE("ssl=%p NativeCrypto_SSL_set_tlsext_host_name => error", ssl);
2336         return;
2337     }
2338     JNI_TRACE("ssl=%p NativeCrypto_SSL_set_tlsext_host_name => ok", ssl);
2339 }
2340 
NativeCrypto_SSL_get_servername(JNIEnv * env,jclass,jint ssl_address)2341 static jstring NativeCrypto_SSL_get_servername(JNIEnv* env, jclass, jint ssl_address) {
2342     SSL* ssl = to_SSL(env, ssl_address, true);
2343     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_servername", ssl);
2344     if (ssl == NULL) {
2345         return NULL;
2346     }
2347     const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
2348     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_servername => %s", ssl, servername);
2349     return env->NewStringUTF(servername);
2350 }
2351 
2352 /**
2353  * Perform SSL handshake
2354  */
NativeCrypto_SSL_do_handshake(JNIEnv * env,jclass,jint ssl_address,jobject fdObject,jobject shc,jint timeout,jboolean client_mode)2355 static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
2356     jint ssl_address, jobject fdObject, jobject shc, jint timeout, jboolean client_mode)
2357 {
2358     SSL* ssl = to_SSL(env, ssl_address, true);
2359     JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake fd=%p shc=%p timeout=%d client_mode=%d",
2360               ssl, fdObject, shc, timeout, client_mode);
2361     if (ssl == NULL) {
2362       return 0;
2363     }
2364     if (fdObject == NULL) {
2365         jniThrowNullPointerException(env, "fd == null");
2366         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2367         return 0;
2368     }
2369     if (shc == NULL) {
2370         jniThrowNullPointerException(env, "sslHandshakeCallbacks == null");
2371         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2372         return 0;
2373     }
2374 
2375     NetFd fd(env, fdObject);
2376     if (fd.isClosed()) {
2377         // SocketException thrown by NetFd.isClosed
2378         SSL_clear(ssl);
2379         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2380         return 0;
2381     }
2382 
2383     int ret = SSL_set_fd(ssl, fd.get());
2384     JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake s=%d", ssl, fd.get());
2385 
2386     if (ret != 1) {
2387         throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_NONE,
2388                                        "Error setting the file descriptor");
2389         SSL_clear(ssl);
2390         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2391         return 0;
2392     }
2393 
2394     /*
2395      * Make socket non-blocking, so SSL_connect SSL_read() and SSL_write() don't hang
2396      * forever and we can use select() to find out if the socket is ready.
2397      */
2398     if (!setBlocking(fd.get(), false)) {
2399         throwSSLExceptionStr(env, "Unable to make socket non blocking");
2400         SSL_clear(ssl);
2401         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2402         return 0;
2403     }
2404 
2405     /*
2406      * Create our special application data.
2407      */
2408     AppData* appData = AppData::create();
2409     if (appData == NULL) {
2410         throwSSLExceptionStr(env, "Unable to create application data");
2411         SSL_clear(ssl);
2412         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2413         return 0;
2414     }
2415     SSL_set_app_data(ssl, reinterpret_cast<char*>(appData));
2416     JNI_TRACE("ssl=%p AppData::create => %p", ssl, appData);
2417 
2418     if (client_mode) {
2419         SSL_set_connect_state(ssl);
2420     } else {
2421         SSL_set_accept_state(ssl);
2422     }
2423 
2424     ret = 0;
2425     while (appData->aliveAndKicking) {
2426         errno = 0;
2427 
2428         if (!appData->setCallbackState(env, shc, fdObject)) {
2429             // SocketException thrown by NetFd.isClosed
2430             SSL_clear(ssl);
2431             JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2432             return 0;
2433         }
2434         ret = SSL_do_handshake(ssl);
2435         appData->clearCallbackState();
2436         // cert_verify_callback threw exception
2437         if (env->ExceptionCheck()) {
2438             SSL_clear(ssl);
2439             JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2440             return 0;
2441         }
2442         // success case
2443         if (ret == 1) {
2444             break;
2445         }
2446         // retry case
2447         if (errno == EINTR) {
2448             continue;
2449         }
2450         // error case
2451         int sslError = SSL_get_error(ssl, ret);
2452         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake ret=%d errno=%d sslError=%d timeout=%d",
2453                   ssl, ret, errno, sslError, timeout);
2454 
2455         /*
2456          * If SSL_do_handshake doesn't succeed due to the socket being
2457          * either unreadable or unwritable, we use sslSelect to
2458          * wait for it to become ready. If that doesn't happen
2459          * before the specified timeout or an error occurs, we
2460          * cancel the handshake. Otherwise we try the SSL_connect
2461          * again.
2462          */
2463         if (sslError == SSL_ERROR_WANT_READ || sslError == SSL_ERROR_WANT_WRITE) {
2464             appData->waitingThreads++;
2465             int selectResult = sslSelect(env, sslError, fdObject, appData, timeout);
2466 
2467             if (selectResult == THROWN_SOCKETEXCEPTION) {
2468                 // SocketException thrown by NetFd.isClosed
2469                 SSL_clear(ssl);
2470                 JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2471                 return 0;
2472             }
2473             if (selectResult == -1) {
2474                 throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_SYSCALL, "handshake error");
2475                 SSL_clear(ssl);
2476                 JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2477                 return 0;
2478             }
2479             if (selectResult == 0) {
2480                 throwSocketTimeoutException(env, "SSL handshake timed out");
2481                 SSL_clear(ssl);
2482                 freeSslErrorState();
2483                 JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2484                 return 0;
2485             }
2486         } else {
2487             // LOGE("Unknown error %d during handshake", error);
2488             break;
2489         }
2490     }
2491 
2492     // clean error. See SSL_do_handshake(3SSL) man page.
2493     if (ret == 0) {
2494         /*
2495          * The other side closed the socket before the handshake could be
2496          * completed, but everything is within the bounds of the TLS protocol.
2497          * We still might want to find out the real reason of the failure.
2498          */
2499         int sslError = SSL_get_error(ssl, ret);
2500         if (sslError == SSL_ERROR_NONE || (sslError == SSL_ERROR_SYSCALL && errno == 0)) {
2501             throwSSLExceptionStr(env, "Connection closed by peer");
2502         } else {
2503             throwSSLExceptionWithSslErrors(env, ssl, sslError, "SSL handshake terminated");
2504         }
2505         SSL_clear(ssl);
2506         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2507         return 0;
2508     }
2509 
2510     // unclean error. See SSL_do_handshake(3SSL) man page.
2511     if (ret < 0) {
2512         /*
2513          * Translate the error and throw exception. We are sure it is an error
2514          * at this point.
2515          */
2516         int sslError = SSL_get_error(ssl, ret);
2517         throwSSLExceptionWithSslErrors(env, ssl, sslError, "SSL handshake aborted");
2518         SSL_clear(ssl);
2519         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
2520         return 0;
2521     }
2522     SSL_SESSION* ssl_session = SSL_get1_session(ssl);
2523     JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => ssl_session=%p", ssl, ssl_session);
2524     return (jint) ssl_session;
2525 }
2526 
2527 /**
2528  * Perform SSL renegotiation
2529  */
NativeCrypto_SSL_renegotiate(JNIEnv * env,jclass,jint ssl_address)2530 static void NativeCrypto_SSL_renegotiate(JNIEnv* env, jclass, jint ssl_address)
2531 {
2532     SSL* ssl = to_SSL(env, ssl_address, true);
2533     JNI_TRACE("ssl=%p NativeCrypto_SSL_renegotiate", ssl);
2534     if (ssl == NULL) {
2535         return;
2536     }
2537     int result = SSL_renegotiate(ssl);
2538     if (result != 1) {
2539         throwSSLExceptionStr(env, "Problem with SSL_renegotiate");
2540         return;
2541     }
2542     int ret = SSL_do_handshake(ssl);
2543     if (ret != 1) {
2544         int sslError = SSL_get_error(ssl, ret);
2545         throwSSLExceptionWithSslErrors(env, ssl, sslError,
2546                                        "Problem with SSL_do_handshake after SSL_renegotiate");
2547 
2548         return;
2549     }
2550     JNI_TRACE("ssl=%p NativeCrypto_SSL_renegotiate =>", ssl);
2551 }
2552 
2553 /**
2554  * public static native byte[][] SSL_get_certificate(int ssl);
2555  */
NativeCrypto_SSL_get_certificate(JNIEnv * env,jclass,jint ssl_address)2556 static jobjectArray NativeCrypto_SSL_get_certificate(JNIEnv* env, jclass, jint ssl_address)
2557 {
2558     SSL* ssl = to_SSL(env, ssl_address, true);
2559     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_certificate", ssl);
2560     if (ssl == NULL) {
2561         return NULL;
2562     }
2563     X509* certificate = SSL_get_certificate(ssl);
2564     if (certificate == NULL) {
2565         JNI_TRACE("ssl=%p NativeCrypto_SSL_get_certificate => NULL", ssl);
2566         return NULL;
2567     }
2568 
2569     Unique_sk_X509 chain(sk_X509_new_null());
2570     if (chain.get() == NULL) {
2571         jniThrowOutOfMemoryError(env, "Unable to allocate local certificate chain");
2572         JNI_TRACE("ssl=%p NativeCrypto_SSL_get_certificate => threw exception", ssl);
2573         return NULL;
2574     }
2575     if (!sk_X509_push(chain.get(), certificate)) {
2576         jniThrowOutOfMemoryError(env, "Unable to push local certificate");
2577         JNI_TRACE("ssl=%p NativeCrypto_SSL_get_certificate => NULL", ssl);
2578         return NULL;
2579     }
2580     STACK_OF(X509)* cert_chain = SSL_get_certificate_chain(ssl, certificate);
2581     for (int i=0; i<sk_X509_num(cert_chain); i++) {
2582         if (!sk_X509_push(chain.get(), sk_X509_value(cert_chain, i))) {
2583             jniThrowOutOfMemoryError(env, "Unable to push local certificate chain");
2584             JNI_TRACE("ssl=%p NativeCrypto_SSL_get_certificate => NULL", ssl);
2585             return NULL;
2586         }
2587     }
2588 
2589     jobjectArray objectArray = getCertificateBytes(env, chain.get());
2590     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_certificate => %p", ssl, objectArray);
2591     return objectArray;
2592 }
2593 
2594 // Fills a byte[][] with the peer certificates in the chain.
NativeCrypto_SSL_get_peer_cert_chain(JNIEnv * env,jclass,jint ssl_address)2595 static jobjectArray NativeCrypto_SSL_get_peer_cert_chain(JNIEnv* env, jclass, jint ssl_address)
2596 {
2597     SSL* ssl = to_SSL(env, ssl_address, true);
2598     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_peer_cert_chain", ssl);
2599     if (ssl == NULL) {
2600         return NULL;
2601     }
2602     STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
2603     Unique_sk_X509 chain_copy(NULL);
2604     if (ssl->server) {
2605         X509* x509 = SSL_get_peer_certificate(ssl);
2606         if (x509 == NULL) {
2607             JNI_TRACE("ssl=%p NativeCrypto_SSL_get_peer_cert_chain => NULL", ssl);
2608             return NULL;
2609         }
2610         chain_copy.reset(sk_X509_dup(chain));
2611         if (chain_copy.get() == NULL) {
2612             jniThrowOutOfMemoryError(env, "Unable to allocate peer certificate chain");
2613             JNI_TRACE("ssl=%p NativeCrypto_SSL_get_peer_cert_chain => certificate dup error", ssl);
2614             return NULL;
2615         }
2616         if (!sk_X509_push(chain_copy.get(), x509)) {
2617             jniThrowOutOfMemoryError(env, "Unable to push server's peer certificate");
2618             JNI_TRACE("ssl=%p NativeCrypto_SSL_get_peer_cert_chain => certificate push error", ssl);
2619             return NULL;
2620         }
2621         chain = chain_copy.get();
2622     }
2623     jobjectArray objectArray = getCertificateBytes(env, chain);
2624     JNI_TRACE("ssl=%p NativeCrypto_SSL_get_peer_cert_chain => %p", ssl, objectArray);
2625     return objectArray;
2626 }
2627 
2628 /**
2629  * Helper function which does the actual reading. The Java layer guarantees that
2630  * at most one thread will enter this function at any given time.
2631  *
2632  * @param ssl non-null; the SSL context
2633  * @param buf non-null; buffer to read into
2634  * @param len length of the buffer, in bytes
2635  * @param sslReturnCode original SSL return code
2636  * @param sslErrorCode filled in with the SSL error code in case of error
2637  * @return number of bytes read on success, -1 if the connection was
2638  * cleanly shut down, or THROW_EXCEPTION if an exception should be thrown.
2639  */
sslRead(JNIEnv * env,SSL * ssl,jobject fdObject,jobject shc,char * buf,jint len,int * sslReturnCode,int * sslErrorCode,int timeout)2640 static int sslRead(JNIEnv* env, SSL* ssl, jobject fdObject, jobject shc, char* buf, jint len,
2641                    int* sslReturnCode, int* sslErrorCode, int timeout) {
2642 
2643     // LOGD("Entering sslRead, caller requests to read %d bytes...", len);
2644 
2645     if (len == 0) {
2646         // Don't bother doing anything in this case.
2647         return 0;
2648     }
2649 
2650     BIO* bio = SSL_get_rbio(ssl);
2651 
2652     AppData* appData = toAppData(ssl);
2653     if (appData == NULL) {
2654         return THROW_EXCEPTION;
2655     }
2656 
2657     while (appData->aliveAndKicking) {
2658         errno = 0;
2659 
2660         if (MUTEX_LOCK(appData->mutex) == -1) {
2661             return -1;
2662         }
2663 
2664         unsigned int bytesMoved = BIO_number_read(bio) + BIO_number_written(bio);
2665 
2666         // LOGD("Doing SSL_Read()");
2667         if (!appData->setCallbackState(env, shc, fdObject)) {
2668             MUTEX_UNLOCK(appData->mutex);
2669             return THROWN_SOCKETEXCEPTION;
2670         }
2671         int result = SSL_read(ssl, buf, len);
2672         appData->clearCallbackState();
2673         int sslError = SSL_ERROR_NONE;
2674         if (result <= 0) {
2675             sslError = SSL_get_error(ssl, result);
2676             freeSslErrorState();
2677         }
2678         // LOGD("Returned from SSL_Read() with result %d, error code %d", result, sslError);
2679 
2680         // If we have been successful in moving data around, check whether it
2681         // might make sense to wake up other blocked threads, so they can give
2682         // it a try, too.
2683         if (BIO_number_read(bio) + BIO_number_written(bio) != bytesMoved
2684                 && appData->waitingThreads > 0) {
2685             sslNotify(appData);
2686         }
2687 
2688         // If we are blocked by the underlying socket, tell the world that
2689         // there will be one more waiting thread now.
2690         if (sslError == SSL_ERROR_WANT_READ || sslError == SSL_ERROR_WANT_WRITE) {
2691             appData->waitingThreads++;
2692         }
2693 
2694         MUTEX_UNLOCK(appData->mutex);
2695 
2696         switch (sslError) {
2697             // Successfully read at least one byte.
2698             case SSL_ERROR_NONE: {
2699                 return result;
2700             }
2701 
2702             // Read zero bytes. End of stream reached.
2703             case SSL_ERROR_ZERO_RETURN: {
2704                 return -1;
2705             }
2706 
2707             // Need to wait for availability of underlying layer, then retry.
2708             case SSL_ERROR_WANT_READ:
2709             case SSL_ERROR_WANT_WRITE: {
2710                 int selectResult = sslSelect(env, sslError, fdObject, appData, timeout);
2711                 if (selectResult == THROWN_SOCKETEXCEPTION) {
2712                     return THROWN_SOCKETEXCEPTION;
2713                 }
2714                 if (selectResult == -1) {
2715                     *sslReturnCode = -1;
2716                     *sslErrorCode = sslError;
2717                     return THROW_EXCEPTION;
2718                 }
2719                 if (selectResult == 0) {
2720                     return THROW_SOCKETTIMEOUTEXCEPTION;
2721                 }
2722 
2723                 break;
2724             }
2725 
2726             // A problem occurred during a system call, but this is not
2727             // necessarily an error.
2728             case SSL_ERROR_SYSCALL: {
2729                 // Connection closed without proper shutdown. Tell caller we
2730                 // have reached end-of-stream.
2731                 if (result == 0) {
2732                     return -1;
2733                 }
2734 
2735                 // System call has been interrupted. Simply retry.
2736                 if (errno == EINTR) {
2737                     break;
2738                 }
2739 
2740                 // Note that for all other system call errors we fall through
2741                 // to the default case, which results in an Exception.
2742             }
2743 
2744             // Everything else is basically an error.
2745             default: {
2746                 *sslReturnCode = result;
2747                 *sslErrorCode = sslError;
2748                 return THROW_EXCEPTION;
2749             }
2750         }
2751     }
2752 
2753     return -1;
2754 }
2755 
2756 /**
2757  * OpenSSL read function (1): only one chunk is read (returned as jint).
2758  */
NativeCrypto_SSL_read_byte(JNIEnv * env,jclass,jint ssl_address,jobject fdObject,jobject shc,jint timeout)2759 static jint NativeCrypto_SSL_read_byte(JNIEnv* env, jclass, jint ssl_address,
2760                                        jobject fdObject, jobject shc, jint timeout)
2761 {
2762     SSL* ssl = to_SSL(env, ssl_address, true);
2763     JNI_TRACE("ssl=%p NativeCrypto_SSL_read_byte fd=%p shc=%p timeout=%d",
2764               ssl, fdObject, shc, timeout);
2765     if (ssl == NULL) {
2766         return 0;
2767     }
2768     if (fdObject == NULL) {
2769         jniThrowNullPointerException(env, "fd == null");
2770         JNI_TRACE("ssl=%p NativeCrypto_SSL_read_byte => 0", ssl);
2771         return 0;
2772     }
2773     if (shc == NULL) {
2774         jniThrowNullPointerException(env, "sslHandshakeCallbacks == null");
2775         JNI_TRACE("ssl=%p NativeCrypto_SSL_read_byte => 0", ssl);
2776         return 0;
2777     }
2778 
2779     unsigned char byteRead;
2780     int returnCode = 0;
2781     int sslErrorCode = SSL_ERROR_NONE;
2782 
2783     int ret = sslRead(env, ssl, fdObject, shc, reinterpret_cast<char*>(&byteRead), 1,
2784                       &returnCode, &sslErrorCode, timeout);
2785 
2786     int result;
2787     switch (ret) {
2788         case THROW_EXCEPTION:
2789             // See sslRead() regarding improper failure to handle normal cases.
2790             throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Read error");
2791             result = -1;
2792             break;
2793         case THROW_SOCKETTIMEOUTEXCEPTION:
2794             throwSocketTimeoutException(env, "Read timed out");
2795             result = -1;
2796             break;
2797         case THROWN_SOCKETEXCEPTION:
2798             // SocketException thrown by NetFd.isClosed
2799             result = -1;
2800             break;
2801         case -1:
2802             // Propagate EOF upwards.
2803             result = -1;
2804             break;
2805         default:
2806             // Return the actual char read, make sure it stays 8 bits wide.
2807             result = ((jint) byteRead) & 0xFF;
2808             break;
2809     }
2810     JNI_TRACE("ssl=%p NativeCrypto_SSL_read_byte => %d", ssl, result);
2811     return result;
2812 }
2813 
2814 /**
2815  * OpenSSL read function (2): read into buffer at offset n chunks.
2816  * Returns 1 (success) or value <= 0 (failure).
2817  */
NativeCrypto_SSL_read(JNIEnv * env,jclass,jint ssl_address,jobject fdObject,jobject shc,jbyteArray b,jint offset,jint len,jint timeout)2818 static jint NativeCrypto_SSL_read(JNIEnv* env, jclass, jint ssl_address, jobject fdObject,
2819                                   jobject shc, jbyteArray b, jint offset, jint len, jint timeout)
2820 {
2821     SSL* ssl = to_SSL(env, ssl_address, true);
2822     JNI_TRACE("ssl=%p NativeCrypto_SSL_read fd=%p shc=%p b=%p offset=%d len=%d timeout=%d",
2823               ssl, fdObject, shc, b, offset, len, timeout);
2824     if (ssl == NULL) {
2825         return 0;
2826     }
2827     if (fdObject == NULL) {
2828         jniThrowNullPointerException(env, "fd == null");
2829         JNI_TRACE("ssl=%p NativeCrypto_SSL_read => fd == null", ssl);
2830         return 0;
2831     }
2832     if (shc == NULL) {
2833         jniThrowNullPointerException(env, "sslHandshakeCallbacks == null");
2834         JNI_TRACE("ssl=%p NativeCrypto_SSL_read => sslHandshakeCallbacks == null", ssl);
2835         return 0;
2836     }
2837 
2838     ScopedByteArrayRW bytes(env, b);
2839     if (bytes.get() == NULL) {
2840         JNI_TRACE("ssl=%p NativeCrypto_SSL_read => threw exception", ssl);
2841         return 0;
2842     }
2843     int returnCode = 0;
2844     int sslErrorCode = SSL_ERROR_NONE;;
2845 
2846     int ret = sslRead(env, ssl, fdObject, shc, reinterpret_cast<char*>(bytes.get() + offset), len,
2847                       &returnCode, &sslErrorCode, timeout);
2848 
2849     int result;
2850     switch (ret) {
2851         case THROW_EXCEPTION:
2852             // See sslRead() regarding improper failure to handle normal cases.
2853             throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Read error");
2854             result = -1;
2855             break;
2856         case THROW_SOCKETTIMEOUTEXCEPTION:
2857             throwSocketTimeoutException(env, "Read timed out");
2858             result = -1;
2859             break;
2860         case THROWN_SOCKETEXCEPTION:
2861             // SocketException thrown by NetFd.isClosed
2862             result = -1;
2863             break;
2864         default:
2865             result = ret;
2866             break;
2867     }
2868 
2869     JNI_TRACE("ssl=%p NativeCrypto_SSL_read => %d", ssl, result);
2870     return result;
2871 }
2872 
2873 /**
2874  * Helper function which does the actual writing. The Java layer guarantees that
2875  * at most one thread will enter this function at any given time.
2876  *
2877  * @param ssl non-null; the SSL context
2878  * @param buf non-null; buffer to write
2879  * @param len length of the buffer, in bytes
2880  * @param sslReturnCode original SSL return code
2881  * @param sslErrorCode filled in with the SSL error code in case of error
2882  * @return number of bytes read on success, -1 if the connection was
2883  * cleanly shut down, or THROW_EXCEPTION if an exception should be thrown.
2884  */
sslWrite(JNIEnv * env,SSL * ssl,jobject fdObject,jobject shc,const char * buf,jint len,int * sslReturnCode,int * sslErrorCode)2885 static int sslWrite(JNIEnv* env, SSL* ssl, jobject fdObject, jobject shc, const char* buf, jint len,
2886                     int* sslReturnCode, int* sslErrorCode) {
2887 
2888     // LOGD("Entering sslWrite(), caller requests to write %d bytes...", len);
2889 
2890     if (len == 0) {
2891         // Don't bother doing anything in this case.
2892         return 0;
2893     }
2894 
2895     BIO* bio = SSL_get_wbio(ssl);
2896 
2897     AppData* appData = toAppData(ssl);
2898     if (appData == NULL) {
2899         return THROW_EXCEPTION;
2900     }
2901 
2902     int count = len;
2903 
2904     while (appData->aliveAndKicking && len > 0) {
2905         errno = 0;
2906 
2907         if (MUTEX_LOCK(appData->mutex) == -1) {
2908             return -1;
2909         }
2910 
2911         unsigned int bytesMoved = BIO_number_read(bio) + BIO_number_written(bio);
2912 
2913         // LOGD("Doing SSL_write() with %d bytes to go", len);
2914         if (!appData->setCallbackState(env, shc, fdObject)) {
2915             MUTEX_UNLOCK(appData->mutex);
2916             return THROWN_SOCKETEXCEPTION;
2917         }
2918         int result = SSL_write(ssl, buf, len);
2919         appData->clearCallbackState();
2920         int sslError = SSL_ERROR_NONE;
2921         if (result <= 0) {
2922             sslError = SSL_get_error(ssl, result);
2923             freeSslErrorState();
2924         }
2925         // LOGD("Returned from SSL_write() with result %d, error code %d", result, error);
2926 
2927         // If we have been successful in moving data around, check whether it
2928         // might make sense to wake up other blocked threads, so they can give
2929         // it a try, too.
2930         if (BIO_number_read(bio) + BIO_number_written(bio) != bytesMoved
2931                 && appData->waitingThreads > 0) {
2932             sslNotify(appData);
2933         }
2934 
2935         // If we are blocked by the underlying socket, tell the world that
2936         // there will be one more waiting thread now.
2937         if (sslError == SSL_ERROR_WANT_READ || sslError == SSL_ERROR_WANT_WRITE) {
2938             appData->waitingThreads++;
2939         }
2940 
2941         MUTEX_UNLOCK(appData->mutex);
2942 
2943         switch (sslError) {
2944             // Successfully wrote at least one byte.
2945             case SSL_ERROR_NONE: {
2946                 buf += result;
2947                 len -= result;
2948                 break;
2949             }
2950 
2951             // Wrote zero bytes. End of stream reached.
2952             case SSL_ERROR_ZERO_RETURN: {
2953                 return -1;
2954             }
2955 
2956             // Need to wait for availability of underlying layer, then retry.
2957             // The concept of a write timeout doesn't really make sense, and
2958             // it's also not standard Java behavior, so we wait forever here.
2959             case SSL_ERROR_WANT_READ:
2960             case SSL_ERROR_WANT_WRITE: {
2961                 int selectResult = sslSelect(env, sslError, fdObject, appData, 0);
2962                 if (selectResult == THROWN_SOCKETEXCEPTION) {
2963                     return THROWN_SOCKETEXCEPTION;
2964                 }
2965                 if (selectResult == -1) {
2966                     *sslReturnCode = -1;
2967                     *sslErrorCode = sslError;
2968                     return THROW_EXCEPTION;
2969                 }
2970                 if (selectResult == 0) {
2971                     return THROW_SOCKETTIMEOUTEXCEPTION;
2972                 }
2973 
2974                 break;
2975             }
2976 
2977             // A problem occurred during a system call, but this is not
2978             // necessarily an error.
2979             case SSL_ERROR_SYSCALL: {
2980                 // Connection closed without proper shutdown. Tell caller we
2981                 // have reached end-of-stream.
2982                 if (result == 0) {
2983                     return -1;
2984                 }
2985 
2986                 // System call has been interrupted. Simply retry.
2987                 if (errno == EINTR) {
2988                     break;
2989                 }
2990 
2991                 // Note that for all other system call errors we fall through
2992                 // to the default case, which results in an Exception.
2993             }
2994 
2995             // Everything else is basically an error.
2996             default: {
2997                 *sslReturnCode = result;
2998                 *sslErrorCode = sslError;
2999                 return THROW_EXCEPTION;
3000             }
3001         }
3002     }
3003     // LOGD("Successfully wrote %d bytes", count);
3004 
3005     return count;
3006 }
3007 
3008 /**
3009  * OpenSSL write function (1): only one chunk is written.
3010  */
NativeCrypto_SSL_write_byte(JNIEnv * env,jclass,jint ssl_address,jobject fdObject,jobject shc,jint b)3011 static void NativeCrypto_SSL_write_byte(JNIEnv* env, jclass, jint ssl_address,
3012                                         jobject fdObject, jobject shc, jint b)
3013 {
3014     SSL* ssl = to_SSL(env, ssl_address, true);
3015     JNI_TRACE("ssl=%p NativeCrypto_SSL_write_byte fd=%p shc=%p b=%d", ssl, fdObject, shc, b);
3016     if (ssl == NULL) {
3017         return;
3018     }
3019     if (fdObject == NULL) {
3020         jniThrowNullPointerException(env, "fd == null");
3021         JNI_TRACE("ssl=%p NativeCrypto_SSL_write_byte => fd == null", ssl);
3022         return;
3023     }
3024     if (shc == NULL) {
3025         jniThrowNullPointerException(env, "sslHandshakeCallbacks == null");
3026         JNI_TRACE("ssl=%p NativeCrypto_SSL_write_byte => sslHandshakeCallbacks == null", ssl);
3027         return;
3028     }
3029 
3030     int returnCode = 0;
3031     int sslErrorCode = SSL_ERROR_NONE;
3032     char buf[1] = { (char) b };
3033     int ret = sslWrite(env, ssl, fdObject, shc, buf, 1, &returnCode, &sslErrorCode);
3034 
3035     switch (ret) {
3036         case THROW_EXCEPTION:
3037             // See sslWrite() regarding improper failure to handle normal cases.
3038             throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Write error");
3039             break;
3040         case THROW_SOCKETTIMEOUTEXCEPTION:
3041             throwSocketTimeoutException(env, "Write timed out");
3042             break;
3043         case THROWN_SOCKETEXCEPTION:
3044             // SocketException thrown by NetFd.isClosed
3045             break;
3046         default:
3047             break;
3048     }
3049 }
3050 
3051 /**
3052  * OpenSSL write function (2): write into buffer at offset n chunks.
3053  */
NativeCrypto_SSL_write(JNIEnv * env,jclass,jint ssl_address,jobject fdObject,jobject shc,jbyteArray b,jint offset,jint len)3054 static void NativeCrypto_SSL_write(JNIEnv* env, jclass, jint ssl_address, jobject fdObject,
3055                                    jobject shc, jbyteArray b, jint offset, jint len)
3056 {
3057     SSL* ssl = to_SSL(env, ssl_address, true);
3058     JNI_TRACE("ssl=%p NativeCrypto_SSL_write fd=%p shc=%p b=%p offset=%d len=%d",
3059               ssl, fdObject, shc, b, offset, len);
3060     if (ssl == NULL) {
3061         return;
3062     }
3063     if (fdObject == NULL) {
3064         jniThrowNullPointerException(env, "fd == null");
3065         JNI_TRACE("ssl=%p NativeCrypto_SSL_write => fd == null", ssl);
3066         return;
3067     }
3068     if (shc == NULL) {
3069         jniThrowNullPointerException(env, "sslHandshakeCallbacks == null");
3070         JNI_TRACE("ssl=%p NativeCrypto_SSL_write => sslHandshakeCallbacks == null", ssl);
3071         return;
3072     }
3073 
3074     ScopedByteArrayRO bytes(env, b);
3075     if (bytes.get() == NULL) {
3076         JNI_TRACE("ssl=%p NativeCrypto_SSL_write => threw exception", ssl);
3077         return;
3078     }
3079     int returnCode = 0;
3080     int sslErrorCode = SSL_ERROR_NONE;
3081     int ret = sslWrite(env, ssl, fdObject, shc, reinterpret_cast<const char*>(bytes.get() + offset),
3082                        len, &returnCode, &sslErrorCode);
3083 
3084     switch (ret) {
3085         case THROW_EXCEPTION:
3086             // See sslWrite() regarding improper failure to handle normal cases.
3087             throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Write error");
3088             break;
3089         case THROW_SOCKETTIMEOUTEXCEPTION:
3090             throwSocketTimeoutException(env, "Write timed out");
3091             break;
3092         case THROWN_SOCKETEXCEPTION:
3093             // SocketException thrown by NetFd.isClosed
3094             break;
3095         default:
3096             break;
3097     }
3098 }
3099 
3100 /**
3101  * Interrupt any pending IO before closing the socket.
3102  */
NativeCrypto_SSL_interrupt(JNIEnv * env,jclass,jint ssl_address)3103 static void NativeCrypto_SSL_interrupt(
3104         JNIEnv* env, jclass, jint ssl_address) {
3105     SSL* ssl = to_SSL(env, ssl_address, false);
3106     JNI_TRACE("ssl=%p NativeCrypto_SSL_interrupt", ssl);
3107     if (ssl == NULL) {
3108         return;
3109     }
3110 
3111     /*
3112      * Mark the connection as quasi-dead, then send something to the emergency
3113      * file descriptor, so any blocking select() calls are woken up.
3114      */
3115     AppData* appData = toAppData(ssl);
3116     if (appData != NULL) {
3117         appData->aliveAndKicking = 0;
3118 
3119         // At most two threads can be waiting.
3120         sslNotify(appData);
3121         sslNotify(appData);
3122     }
3123 }
3124 
3125 /**
3126  * OpenSSL close SSL socket function.
3127  */
NativeCrypto_SSL_shutdown(JNIEnv * env,jclass,jint ssl_address,jobject fdObject,jobject shc)3128 static void NativeCrypto_SSL_shutdown(JNIEnv* env, jclass, jint ssl_address,
3129                                       jobject fdObject, jobject shc) {
3130     SSL* ssl = to_SSL(env, ssl_address, false);
3131     JNI_TRACE("ssl=%p NativeCrypto_SSL_shutdown fd=%p shc=%p", ssl, fdObject, shc);
3132     if (ssl == NULL) {
3133         return;
3134     }
3135     if (fdObject == NULL) {
3136         jniThrowNullPointerException(env, "fd == null");
3137         JNI_TRACE("ssl=%p NativeCrypto_SSL_shutdown => fd == null", ssl);
3138         return;
3139     }
3140     if (shc == NULL) {
3141         jniThrowNullPointerException(env, "sslHandshakeCallbacks == null");
3142         JNI_TRACE("ssl=%p NativeCrypto_SSL_shutdown => sslHandshakeCallbacks == null", ssl);
3143         return;
3144     }
3145 
3146     AppData* appData = toAppData(ssl);
3147     if (appData != NULL) {
3148         if (!appData->setCallbackState(env, shc, fdObject)) {
3149             // SocketException thrown by NetFd.isClosed
3150             SSL_clear(ssl);
3151             freeSslErrorState();
3152             return;
3153         }
3154 
3155         /*
3156          * Try to make socket blocking again. OpenSSL literature recommends this.
3157          */
3158         int fd = SSL_get_fd(ssl);
3159         JNI_TRACE("ssl=%p NativeCrypto_SSL_shutdown s=%d", ssl, fd);
3160         if (fd != -1) {
3161             setBlocking(fd, true);
3162         }
3163 
3164         int ret = SSL_shutdown(ssl);
3165         switch (ret) {
3166             case 0:
3167                 /*
3168                  * Shutdown was not successful (yet), but there also
3169                  * is no error. Since we can't know whether the remote
3170                  * server is actually still there, and we don't want to
3171                  * get stuck forever in a second SSL_shutdown() call, we
3172                  * simply return. This is not security a problem as long
3173                  * as we close the underlying socket, which we actually
3174                  * do, because that's where we are just coming from.
3175                  */
3176                 break;
3177             case 1:
3178                 /*
3179                  * Shutdown was successful. We can safely return. Hooray!
3180                  */
3181                 break;
3182             default:
3183                 /*
3184                  * Everything else is a real error condition. We should
3185                  * let the Java layer know about this by throwing an
3186                  * exception.
3187                  */
3188                 int sslError = SSL_get_error(ssl, ret);
3189                 throwSSLExceptionWithSslErrors(env, ssl, sslError, "SSL shutdown failed");
3190                 break;
3191         }
3192         appData->clearCallbackState();
3193     }
3194 
3195     SSL_clear(ssl);
3196     freeSslErrorState();
3197 }
3198 
3199 /**
3200  * public static native void SSL_free(int ssl);
3201  */
NativeCrypto_SSL_free(JNIEnv * env,jclass,jint ssl_address)3202 static void NativeCrypto_SSL_free(JNIEnv* env, jclass, jint ssl_address)
3203 {
3204     SSL* ssl = to_SSL(env, ssl_address, true);
3205     JNI_TRACE("ssl=%p NativeCrypto_SSL_free", ssl);
3206     if (ssl == NULL) {
3207         return;
3208     }
3209 
3210     AppData* appData = toAppData(ssl);
3211     SSL_set_app_data(ssl, NULL);
3212     delete appData;
3213     SSL_free(ssl);
3214 }
3215 
3216 /**
3217  * Gets and returns in a byte array the ID of the actual SSL session.
3218  */
NativeCrypto_SSL_SESSION_session_id(JNIEnv * env,jclass,jint ssl_session_address)3219 static jbyteArray NativeCrypto_SSL_SESSION_session_id(JNIEnv* env, jclass,
3220                                                       jint ssl_session_address) {
3221     SSL_SESSION* ssl_session = to_SSL_SESSION(env, ssl_session_address, true);
3222     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_session_id", ssl_session);
3223     if (ssl_session == NULL) {
3224         return NULL;
3225     }
3226     jbyteArray result = env->NewByteArray(ssl_session->session_id_length);
3227     if (result != NULL) {
3228         jbyte* src = reinterpret_cast<jbyte*>(ssl_session->session_id);
3229         env->SetByteArrayRegion(result, 0, ssl_session->session_id_length, src);
3230     }
3231     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_session_id => %p session_id_length=%d",
3232              ssl_session, result, ssl_session->session_id_length);
3233     return result;
3234 }
3235 
3236 /**
3237  * Gets and returns in a long integer the creation's time of the
3238  * actual SSL session.
3239  */
NativeCrypto_SSL_SESSION_get_time(JNIEnv * env,jclass,jint ssl_session_address)3240 static jlong NativeCrypto_SSL_SESSION_get_time(JNIEnv* env, jclass, jint ssl_session_address) {
3241     SSL_SESSION* ssl_session = to_SSL_SESSION(env, ssl_session_address, true);
3242     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_get_time", ssl_session);
3243     if (ssl_session == NULL) {
3244         return 0;
3245     }
3246     // result must be jlong, not long or *1000 will overflow
3247     jlong result = SSL_SESSION_get_time(ssl_session);
3248     result *= 1000; // OpenSSL uses seconds, Java uses milliseconds.
3249     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_get_time => %lld", ssl_session, result);
3250     return result;
3251 }
3252 
3253 /**
3254  * Our implementation of what might be considered
3255  * SSL_SESSION_get_version, based on SSL_get_version.
3256  * See get_ssl_version above.
3257  */
3258 // TODO move to jsse.patch
SSL_SESSION_get_version(SSL_SESSION * ssl_session)3259 static const char* SSL_SESSION_get_version(SSL_SESSION* ssl_session) {
3260   return get_ssl_version(ssl_session->ssl_version);
3261 }
3262 
3263 /**
3264  * Gets and returns in a string the version of the SSL protocol. If it
3265  * returns the string "unknown" it means that no connection is established.
3266  */
NativeCrypto_SSL_SESSION_get_version(JNIEnv * env,jclass,jint ssl_session_address)3267 static jstring NativeCrypto_SSL_SESSION_get_version(JNIEnv* env, jclass, jint ssl_session_address) {
3268     SSL_SESSION* ssl_session = to_SSL_SESSION(env, ssl_session_address, true);
3269     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_get_version", ssl_session);
3270     if (ssl_session == NULL) {
3271         return NULL;
3272     }
3273     const char* protocol = SSL_SESSION_get_version(ssl_session);
3274     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_get_version => %s", ssl_session, protocol);
3275     return env->NewStringUTF(protocol);
3276 }
3277 
3278 /**
3279  * Gets and returns in a string the cipher negotiated for the SSL session.
3280  */
NativeCrypto_SSL_SESSION_cipher(JNIEnv * env,jclass,jint ssl_session_address)3281 static jstring NativeCrypto_SSL_SESSION_cipher(JNIEnv* env, jclass, jint ssl_session_address) {
3282     SSL_SESSION* ssl_session = to_SSL_SESSION(env, ssl_session_address, true);
3283     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_cipher", ssl_session);
3284     if (ssl_session == NULL) {
3285         return NULL;
3286     }
3287     const SSL_CIPHER* cipher = ssl_session->cipher;
3288     const char* name = SSL_CIPHER_get_name(cipher);
3289     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_cipher => %s", ssl_session, name);
3290     return env->NewStringUTF(name);
3291 }
3292 
3293 /**
3294  * Gets and returns in a string the compression method negotiated for the SSL session.
3295  */
NativeCrypto_SSL_SESSION_compress_meth(JNIEnv * env,jclass,jint ssl_ctx_address,jint ssl_session_address)3296 static jstring NativeCrypto_SSL_SESSION_compress_meth(JNIEnv* env, jclass,
3297                                                       jint ssl_ctx_address,
3298                                                       jint ssl_session_address) {
3299     SSL_CTX* ssl_ctx = to_SSL_CTX(env, ssl_ctx_address, true);
3300     SSL_SESSION* ssl_session = to_SSL_SESSION(env, ssl_session_address, true);
3301     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_compress_meth ssl_ctx=%p",
3302               ssl_session, ssl_ctx);
3303     if (ssl_ctx == NULL || ssl_session == NULL) {
3304         return NULL;
3305     }
3306 
3307     int compress_meth = ssl_session->compress_meth;
3308     if (compress_meth == 0) {
3309         const char* name = "NULL";
3310         JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_compress_meth => %s", ssl_session, name);
3311         return env->NewStringUTF(name);
3312     }
3313 
3314     int num_comp_methods = sk_SSL_COMP_num(ssl_ctx->comp_methods);
3315     for (int i = 0; i < num_comp_methods; i++) {
3316         SSL_COMP* comp = sk_SSL_COMP_value(ssl_ctx->comp_methods, i);
3317         if (comp->id != compress_meth) {
3318             continue;
3319         }
3320         const char* name = ((comp->method && comp->method->type == NID_zlib_compression)
3321                             ? SN_zlib_compression
3322                             : (comp->name ? comp->name : "UNKNOWN"));
3323         JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_compress_meth => %s", ssl_session, name);
3324         return env->NewStringUTF(name);
3325     }
3326     throwSSLExceptionStr(env, "Unknown compression method");
3327     return NULL;
3328 }
3329 
3330 /**
3331  * Frees the SSL session.
3332  */
NativeCrypto_SSL_SESSION_free(JNIEnv * env,jclass,jint ssl_session_address)3333 static void NativeCrypto_SSL_SESSION_free(JNIEnv* env, jclass, jint ssl_session_address) {
3334     SSL_SESSION* ssl_session = to_SSL_SESSION(env, ssl_session_address, true);
3335     JNI_TRACE("ssl_session=%p NativeCrypto_SSL_SESSION_free", ssl_session);
3336     if (ssl_session == NULL) {
3337         return;
3338     }
3339     SSL_SESSION_free(ssl_session);
3340 }
3341 
3342 
3343 /**
3344  * Serializes the native state of the session (ID, cipher, and keys but
3345  * not certificates). Returns a byte[] containing the DER-encoded state.
3346  * See apache mod_ssl.
3347  */
NativeCrypto_i2d_SSL_SESSION(JNIEnv * env,jclass,jint ssl_session_address)3348 static jbyteArray NativeCrypto_i2d_SSL_SESSION(JNIEnv* env, jclass, jint ssl_session_address) {
3349     SSL_SESSION* ssl_session = to_SSL_SESSION(env, ssl_session_address, true);
3350     JNI_TRACE("ssl_session=%p NativeCrypto_i2d_SSL_SESSION", ssl_session);
3351     if (ssl_session == NULL) {
3352         return NULL;
3353     }
3354 
3355     // Compute the size of the DER data
3356     int size = i2d_SSL_SESSION(ssl_session, NULL);
3357     if (size == 0) {
3358         JNI_TRACE("ssl_session=%p NativeCrypto_i2d_SSL_SESSION => NULL", ssl_session);
3359         return NULL;
3360     }
3361 
3362     jbyteArray javaBytes = env->NewByteArray(size);
3363     if (javaBytes != NULL) {
3364         ScopedByteArrayRW bytes(env, javaBytes);
3365         if (bytes.get() == NULL) {
3366             JNI_TRACE("ssl_session=%p NativeCrypto_i2d_SSL_SESSION => threw exception",
3367                       ssl_session);
3368             return NULL;
3369         }
3370         unsigned char* ucp = reinterpret_cast<unsigned char*>(bytes.get());
3371         i2d_SSL_SESSION(ssl_session, &ucp);
3372     }
3373 
3374     JNI_TRACE("ssl_session=%p NativeCrypto_i2d_SSL_SESSION => size=%d", ssl_session, size);
3375     return javaBytes;
3376 }
3377 
3378 /**
3379  * Deserialize the session.
3380  */
NativeCrypto_d2i_SSL_SESSION(JNIEnv * env,jclass,jbyteArray javaBytes)3381 static jint NativeCrypto_d2i_SSL_SESSION(JNIEnv* env, jclass, jbyteArray javaBytes) {
3382     JNI_TRACE("NativeCrypto_d2i_SSL_SESSION bytes=%p", javaBytes);
3383 
3384     ScopedByteArrayRO bytes(env, javaBytes);
3385     if (bytes.get() == NULL) {
3386         JNI_TRACE("NativeCrypto_d2i_SSL_SESSION => threw exception");
3387         return 0;
3388     }
3389     const unsigned char* ucp = reinterpret_cast<const unsigned char*>(bytes.get());
3390     SSL_SESSION* ssl_session = d2i_SSL_SESSION(NULL, &ucp, bytes.size());
3391 
3392     JNI_TRACE("NativeCrypto_d2i_SSL_SESSION => %p", ssl_session);
3393     return static_cast<jint>(reinterpret_cast<uintptr_t>(ssl_session));
3394 }
3395 
// JNI type descriptors used to build the method signatures below; string
// literal concatenation splices them into each signature string.
#define FILE_DESCRIPTOR "Ljava/io/FileDescriptor;"
#define SSL_CALLBACKS "Lorg/apache/harmony/xnet/provider/jsse/NativeCrypto$SSLHandshakeCallbacks;"
/*
 * Registration table for the native methods of
 * org.apache.harmony.xnet.provider.jsse.NativeCrypto. It is passed to
 * jniRegisterNativeMethods() by
 * register_org_apache_harmony_xnet_provider_jsse_NativeCrypto().
 *
 * Each signature string must match the Java-side native declaration exactly,
 * or registration fails. Native object handles (SSL*, SSL_CTX*, SSL_SESSION*,
 * ...) cross the JNI boundary as 32-bit jint values ("I").
 * NOTE(review): NATIVE_METHOD comes from the project's JNI helper headers and
 * presumably binds the Java name X to the C++ function NativeCrypto_X;
 * confirm there.
 */
static JNINativeMethod sNativeCryptoMethods[] = {
    NATIVE_METHOD(NativeCrypto, clinit, "()V"),
    NATIVE_METHOD(NativeCrypto, EVP_PKEY_new_DSA, "([B[B[B[B[B)I"),
    NATIVE_METHOD(NativeCrypto, EVP_PKEY_new_RSA, "([B[B[B[B[B)I"),
    NATIVE_METHOD(NativeCrypto, EVP_PKEY_free, "(I)V"),
    NATIVE_METHOD(NativeCrypto, EVP_MD_CTX_create, "()I"),
    NATIVE_METHOD(NativeCrypto, EVP_MD_CTX_destroy, "(I)V"),
    NATIVE_METHOD(NativeCrypto, EVP_MD_CTX_copy, "(I)I"),
    NATIVE_METHOD(NativeCrypto, EVP_DigestFinal, "(I[BI)I"),
    NATIVE_METHOD(NativeCrypto, EVP_DigestInit, "(ILjava/lang/String;)V"),
    NATIVE_METHOD(NativeCrypto, EVP_MD_CTX_block_size, "(I)I"),
    NATIVE_METHOD(NativeCrypto, EVP_MD_CTX_size, "(I)I"),
    NATIVE_METHOD(NativeCrypto, EVP_DigestUpdate, "(I[BII)V"),
    NATIVE_METHOD(NativeCrypto, EVP_VerifyInit, "(ILjava/lang/String;)V"),
    NATIVE_METHOD(NativeCrypto, EVP_VerifyUpdate, "(I[BII)V"),
    NATIVE_METHOD(NativeCrypto, EVP_VerifyFinal, "(I[BIII)I"),
    NATIVE_METHOD(NativeCrypto, verifySignature, "([B[BLjava/lang/String;[B[B)I"),
    NATIVE_METHOD(NativeCrypto, RAND_seed, "([B)V"),
    NATIVE_METHOD(NativeCrypto, RAND_load_file, "(Ljava/lang/String;J)I"),
    NATIVE_METHOD(NativeCrypto, SSL_CTX_new, "()I"),
    NATIVE_METHOD(NativeCrypto, SSL_CTX_free, "(I)V"),
    NATIVE_METHOD(NativeCrypto, SSL_new, "(I)I"),
    NATIVE_METHOD(NativeCrypto, SSL_use_PrivateKey, "(I[B)V"),
    NATIVE_METHOD(NativeCrypto, SSL_use_certificate, "(I[[B)V"),
    NATIVE_METHOD(NativeCrypto, SSL_check_private_key, "(I)V"),
    NATIVE_METHOD(NativeCrypto, SSL_set_client_CA_list, "(I[[B)V"),
    NATIVE_METHOD(NativeCrypto, SSL_get_mode, "(I)J"),
    NATIVE_METHOD(NativeCrypto, SSL_set_mode, "(IJ)J"),
    NATIVE_METHOD(NativeCrypto, SSL_clear_mode, "(IJ)J"),
    NATIVE_METHOD(NativeCrypto, SSL_get_options, "(I)J"),
    NATIVE_METHOD(NativeCrypto, SSL_set_options, "(IJ)J"),
    NATIVE_METHOD(NativeCrypto, SSL_clear_options, "(IJ)J"),
    NATIVE_METHOD(NativeCrypto, SSL_set_cipher_lists, "(I[Ljava/lang/String;)V"),
    NATIVE_METHOD(NativeCrypto, SSL_set_verify, "(II)V"),
    NATIVE_METHOD(NativeCrypto, SSL_set_session, "(II)V"),
    NATIVE_METHOD(NativeCrypto, SSL_set_session_creation_enabled, "(IZ)V"),
    NATIVE_METHOD(NativeCrypto, SSL_set_tlsext_host_name, "(ILjava/lang/String;)V"),
    NATIVE_METHOD(NativeCrypto, SSL_get_servername, "(I)Ljava/lang/String;"),
    // I/O methods below take the socket's FileDescriptor and the handshake
    // callbacks object in addition to the SSL handle.
    NATIVE_METHOD(NativeCrypto, SSL_do_handshake, "(I" FILE_DESCRIPTOR SSL_CALLBACKS "IZ)I"),
    NATIVE_METHOD(NativeCrypto, SSL_renegotiate, "(I)V"),
    NATIVE_METHOD(NativeCrypto, SSL_get_certificate, "(I)[[B"),
    NATIVE_METHOD(NativeCrypto, SSL_get_peer_cert_chain, "(I)[[B"),
    NATIVE_METHOD(NativeCrypto, SSL_read_byte, "(I" FILE_DESCRIPTOR SSL_CALLBACKS "I)I"),
    NATIVE_METHOD(NativeCrypto, SSL_read, "(I" FILE_DESCRIPTOR SSL_CALLBACKS "[BIII)I"),
    NATIVE_METHOD(NativeCrypto, SSL_write_byte, "(I" FILE_DESCRIPTOR SSL_CALLBACKS "I)V"),
    NATIVE_METHOD(NativeCrypto, SSL_write, "(I" FILE_DESCRIPTOR SSL_CALLBACKS "[BII)V"),
    NATIVE_METHOD(NativeCrypto, SSL_interrupt, "(I)V"),
    NATIVE_METHOD(NativeCrypto, SSL_shutdown, "(I" FILE_DESCRIPTOR SSL_CALLBACKS ")V"),
    NATIVE_METHOD(NativeCrypto, SSL_free, "(I)V"),
    // SSL_SESSION accessors and (de)serialization.
    NATIVE_METHOD(NativeCrypto, SSL_SESSION_session_id, "(I)[B"),
    NATIVE_METHOD(NativeCrypto, SSL_SESSION_get_time, "(I)J"),
    NATIVE_METHOD(NativeCrypto, SSL_SESSION_get_version, "(I)Ljava/lang/String;"),
    NATIVE_METHOD(NativeCrypto, SSL_SESSION_cipher, "(I)Ljava/lang/String;"),
    NATIVE_METHOD(NativeCrypto, SSL_SESSION_compress_meth, "(II)Ljava/lang/String;"),
    NATIVE_METHOD(NativeCrypto, SSL_SESSION_free, "(I)V"),
    NATIVE_METHOD(NativeCrypto, i2d_SSL_SESSION, "(I)[B"),
    NATIVE_METHOD(NativeCrypto, d2i_SSL_SESSION, "([B)I"),
};
3456 
register_org_apache_harmony_xnet_provider_jsse_NativeCrypto(JNIEnv * env)3457 int register_org_apache_harmony_xnet_provider_jsse_NativeCrypto(JNIEnv* env) {
3458     JNI_TRACE("register_org_apache_harmony_xnet_provider_jsse_NativeCrypto");
3459     // Register org.apache.harmony.xnet.provider.jsse.NativeCrypto methods
3460     return jniRegisterNativeMethods(env,
3461                                     "org/apache/harmony/xnet/provider/jsse/NativeCrypto",
3462                                     sNativeCryptoMethods,
3463                                     NELEM(sNativeCryptoMethods));
3464 }
3465