1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * fs/hmdfs/comm/crypto.c
4 *
5 * Copyright (c) 2020-2021 Huawei Device Co., Ltd.
6 */
7
#include "crypto.h"

#include <crypto/aead.h>
#include <crypto/hash.h>
#include <linux/string.h>
#include <linux/tcp.h>
#include <net/inet_connection_sock.h>
#include <net/tcp_states.h>
#include <net/tls.h>

#include "hmdfs.h"
18
tls_crypto_set_key(struct connection * conn_impl,int tx)19 static void tls_crypto_set_key(struct connection *conn_impl, int tx)
20 {
21 int rc = 0;
22 struct tcp_handle *tcp = conn_impl->connect_handle;
23 struct tls_context *ctx = NULL;
24 struct cipher_context *cctx = NULL;
25 struct tls_sw_context_tx *sw_ctx_tx = NULL;
26 struct tls_sw_context_rx *sw_ctx_rx = NULL;
27 struct crypto_aead **aead = NULL;
28 struct tls12_crypto_info_aes_gcm_128 *crypto_info = NULL;
29
30 lock_sock(tcp->sock->sk);
31 ctx = tls_get_ctx(tcp->sock->sk);
32 if (tx) {
33 crypto_info = &conn_impl->send_crypto_info;
34 cctx = &ctx->tx;
35 sw_ctx_tx = tls_sw_ctx_tx(ctx);
36 aead = &sw_ctx_tx->aead_send;
37 } else {
38 crypto_info = &conn_impl->recv_crypto_info;
39 cctx = &ctx->rx;
40 sw_ctx_rx = tls_sw_ctx_rx(ctx);
41 aead = &sw_ctx_rx->aead_recv;
42 }
43
44 memcpy(cctx->iv, crypto_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
45 memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, crypto_info->iv,
46 TLS_CIPHER_AES_GCM_128_IV_SIZE);
47 memcpy(cctx->rec_seq, crypto_info->rec_seq,
48 TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
49 rc = crypto_aead_setkey(*aead, crypto_info->key,
50 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
51 if (rc)
52 hmdfs_err("crypto set key error");
53 release_sock(tcp->sock->sk);
54 }
55
tls_crypto_info_init(struct connection * conn_impl)56 int tls_crypto_info_init(struct connection *conn_impl)
57 {
58 int ret = 0;
59 u8 key_meterial[HMDFS_KEY_SIZE];
60 struct tcp_handle *tcp =
61 (struct tcp_handle *)(conn_impl->connect_handle);
62 if (!tcp)
63 return -EINVAL;
64 // send
65 update_key(conn_impl->send_key, key_meterial, HKDF_TYPE_IV);
66 ret = tcp->sock->ops->setsockopt(tcp->sock, SOL_TCP, TCP_ULP,
67 KERNEL_SOCKPTR("tls"), sizeof("tls"));
68 if (ret)
69 hmdfs_err("set tls error %d", ret);
70 tcp->connect->send_crypto_info.info.version = TLS_1_2_VERSION;
71 tcp->connect->send_crypto_info.info.cipher_type =
72 TLS_CIPHER_AES_GCM_128;
73
74 memcpy(tcp->connect->send_crypto_info.key, tcp->connect->send_key,
75 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
76 memcpy(tcp->connect->send_crypto_info.iv,
77 key_meterial + CRYPTO_IV_OFFSET, TLS_CIPHER_AES_GCM_128_IV_SIZE);
78 memcpy(tcp->connect->send_crypto_info.salt,
79 key_meterial + CRYPTO_SALT_OFFSET,
80 TLS_CIPHER_AES_GCM_128_SALT_SIZE);
81 memcpy(tcp->connect->send_crypto_info.rec_seq,
82 key_meterial + CRYPTO_SEQ_OFFSET,
83 TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
84
85 ret = tcp->sock->ops->setsockopt(tcp->sock, SOL_TLS, TLS_TX,
86 KERNEL_SOCKPTR(&(tcp->connect->send_crypto_info)),
87 sizeof(tcp->connect->send_crypto_info));
88 if (ret)
89 hmdfs_err("set tls send_crypto_info error %d", ret);
90
91 // recv
92 update_key(tcp->connect->recv_key, key_meterial, HKDF_TYPE_IV);
93 tcp->connect->recv_crypto_info.info.version = TLS_1_2_VERSION;
94 tcp->connect->recv_crypto_info.info.cipher_type =
95 TLS_CIPHER_AES_GCM_128;
96
97 memcpy(tcp->connect->recv_crypto_info.key, tcp->connect->recv_key,
98 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
99 memcpy(tcp->connect->recv_crypto_info.iv,
100 key_meterial + CRYPTO_IV_OFFSET, TLS_CIPHER_AES_GCM_128_IV_SIZE);
101 memcpy(tcp->connect->recv_crypto_info.salt,
102 key_meterial + CRYPTO_SALT_OFFSET,
103 TLS_CIPHER_AES_GCM_128_SALT_SIZE);
104 memcpy(tcp->connect->recv_crypto_info.rec_seq,
105 key_meterial + CRYPTO_SEQ_OFFSET,
106 TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
107 memset(key_meterial, 0, HMDFS_KEY_SIZE);
108
109 ret = tcp->sock->ops->setsockopt(tcp->sock, SOL_TLS, TLS_RX,
110 KERNEL_SOCKPTR(&(tcp->connect->recv_crypto_info)),
111 sizeof(tcp->connect->recv_crypto_info));
112 if (ret)
113 hmdfs_err("set tls recv_crypto_info error %d", ret);
114 return ret;
115 }
116
tls_set_tx(struct tcp_handle * tcp)117 static int tls_set_tx(struct tcp_handle *tcp)
118 {
119 int ret = 0;
120 u8 new_key[HMDFS_KEY_SIZE];
121 u8 key_meterial[HMDFS_KEY_SIZE];
122
123 ret = update_key(tcp->connect->send_key, new_key, HKDF_TYPE_REKEY);
124 if (ret < 0)
125 return ret;
126 memcpy(tcp->connect->send_key, new_key, HMDFS_KEY_SIZE);
127 ret = update_key(tcp->connect->send_key, key_meterial, HKDF_TYPE_IV);
128 if (ret < 0)
129 return ret;
130
131 memcpy(tcp->connect->send_crypto_info.key, tcp->connect->send_key,
132 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
133 memcpy(tcp->connect->send_crypto_info.iv,
134 key_meterial + CRYPTO_IV_OFFSET, TLS_CIPHER_AES_GCM_128_IV_SIZE);
135 memcpy(tcp->connect->send_crypto_info.salt,
136 key_meterial + CRYPTO_SALT_OFFSET,
137 TLS_CIPHER_AES_GCM_128_SALT_SIZE);
138 memcpy(tcp->connect->send_crypto_info.rec_seq,
139 key_meterial + CRYPTO_SEQ_OFFSET,
140 TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
141 memset(new_key, 0, HMDFS_KEY_SIZE);
142 memset(key_meterial, 0, HMDFS_KEY_SIZE);
143
144 tls_crypto_set_key(tcp->connect, 1);
145 return 0;
146 }
147
tls_set_rx(struct tcp_handle * tcp)148 static int tls_set_rx(struct tcp_handle *tcp)
149 {
150 int ret = 0;
151 u8 new_key[HMDFS_KEY_SIZE];
152 u8 key_meterial[HMDFS_KEY_SIZE];
153
154 ret = update_key(tcp->connect->recv_key, new_key, HKDF_TYPE_REKEY);
155 if (ret < 0)
156 return ret;
157 memcpy(tcp->connect->recv_key, new_key, HMDFS_KEY_SIZE);
158 ret = update_key(tcp->connect->recv_key, key_meterial, HKDF_TYPE_IV);
159 if (ret < 0)
160 return ret;
161
162 memcpy(tcp->connect->recv_crypto_info.key, tcp->connect->recv_key,
163 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
164 memcpy(tcp->connect->recv_crypto_info.iv,
165 key_meterial + CRYPTO_IV_OFFSET, TLS_CIPHER_AES_GCM_128_IV_SIZE);
166 memcpy(tcp->connect->recv_crypto_info.salt,
167 key_meterial + CRYPTO_SALT_OFFSET,
168 TLS_CIPHER_AES_GCM_128_SALT_SIZE);
169 memcpy(tcp->connect->recv_crypto_info.rec_seq,
170 key_meterial + CRYPTO_SEQ_OFFSET,
171 TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
172 memset(new_key, 0, HMDFS_KEY_SIZE);
173 memset(key_meterial, 0, HMDFS_KEY_SIZE);
174 tls_crypto_set_key(tcp->connect, 0);
175 return 0;
176 }
177
set_crypto_info(struct connection * conn_impl,int set_type)178 int set_crypto_info(struct connection *conn_impl, int set_type)
179 {
180 int ret = 0;
181 struct tcp_handle *tcp =
182 (struct tcp_handle *)(conn_impl->connect_handle);
183 if (!tcp)
184 return -EINVAL;
185
186 if (set_type == SET_CRYPTO_SEND) {
187 ret = tls_set_tx(tcp);
188 if (ret) {
189 hmdfs_err("tls set tx fail");
190 return ret;
191 }
192 }
193 if (set_type == SET_CRYPTO_RECV) {
194 ret = tls_set_rx(tcp);
195 if (ret) {
196 hmdfs_err("tls set rx fail");
197 return ret;
198 }
199 }
200 hmdfs_info("KTLS setting success");
201 return ret;
202 }
203
/*
 * hmac_sha256 - compute HMAC-SHA256(@key, @info) into @output.
 * @key: HMAC key; must not be NULL
 * @key_len: length of @key in bytes
 * @info: message to authenticate
 * @info_len: length of @info in bytes
 * @output: receives the HMAC-SHA256 digest (32 bytes); caller sizes it
 *
 * Return: 0 on success, -EINVAL for a NULL key, or a negative error from
 * the crypto API.
 */
static int hmac_sha256(u8 *key, u8 key_len, char *info, u8 info_len, u8 *output)
{
	struct crypto_shash *tfm = NULL;
	int ret;

	if (!key)
		return -EINVAL;

	tfm = crypto_alloc_shash("hmac(sha256)", 0, 0);
	if (IS_ERR(tfm)) {
		hmdfs_err("crypto_alloc_shash failed: err %ld", PTR_ERR(tfm));
		return PTR_ERR(tfm);
	}

	ret = crypto_shash_setkey(tfm, key, key_len);
	if (ret) {
		hmdfs_err("crypto_shash_setkey failed: err %d", ret);
		goto out;
	}

	/*
	 * crypto_shash_tfm_digest() runs on an on-stack descriptor and zeroes
	 * it afterwards, so keyed HMAC state is not left in freed heap memory.
	 */
	ret = crypto_shash_tfm_digest(tfm, (const u8 *)info, info_len, output);
out:
	crypto_free_shash(tfm);
	return ret;
}
242
/*
 * HMAC labels used by update_key(), indexed by its @type argument (the
 * HKDF_TYPE_* values; presumably initiator, accepter, rekey, IV/salt in
 * that order, so confirm against the HKDF_TYPE_* definitions). The label
 * bytes are key-derivation input and must stay identical on both peers.
 */
#define HMDFS_KTLS_LABEL_INITIATOR "ktls key initiator"
#define HMDFS_KTLS_LABEL_ACCEPTER "ktls key accepter"
#define HMDFS_KTLS_LABEL_UPDATE "ktls key update"
#define HMDFS_KTLS_LABEL_IV "ktls iv&salt"

static const char *const g_key_lable[] = {
	HMDFS_KTLS_LABEL_INITIATOR,
	HMDFS_KTLS_LABEL_ACCEPTER,
	HMDFS_KTLS_LABEL_UPDATE,
	HMDFS_KTLS_LABEL_IV,
};

/* Label lengths excluding the NUL, derived from the literals so they cannot drift. */
static const int g_key_lable_len[] = {
	sizeof(HMDFS_KTLS_LABEL_INITIATOR) - 1,
	sizeof(HMDFS_KTLS_LABEL_ACCEPTER) - 1,
	sizeof(HMDFS_KTLS_LABEL_UPDATE) - 1,
	sizeof(HMDFS_KTLS_LABEL_IV) - 1,
};
247
/*
 * update_key - derive key material from @old_key with HMAC-SHA256.
 * @old_key: current secret, used as an HMDFS_KEY_SIZE-byte HMAC key
 * @new_key: output buffer receiving the HMAC-SHA256 digest
 * @type: index into g_key_lable / g_key_lable_len (an HKDF_TYPE_* value)
 *
 * The HMAC message is a u16 holding HMDFS_KEY_SIZE in host byte order,
 * followed by the label bytes (no NUL), followed by a single 0x01 byte.
 * This resembles one HKDF-Expand block over a TLS-1.3-style label. The
 * length prefix is host-endian rather than network order; both peers
 * must agree on it, so it is left as is.
 *
 * Return: 0 on success, -EINVAL for an out-of-range @type or a label that
 * does not fit the on-stack buffer, or the error from hmac_sha256().
 */
int update_key(__u8 *old_key, __u8 *new_key, int type)
{
	char lable[MAX_LABLE_SIZE];
	u16 out_len = HMDFS_KEY_SIZE;
	size_t lable_len;
	u8 lable_size;
	int ret;

	if (type < 0 ||
	    (size_t)type >= sizeof(g_key_lable) / sizeof(g_key_lable[0])) {
		hmdfs_err("invalid key type %d", type);
		return -EINVAL;
	}

	lable_len = g_key_lable_len[type];
	if (sizeof(out_len) + lable_len + 1 > sizeof(lable)) {
		hmdfs_err("key label too long");
		return -EINVAL;
	}
	lable_size = sizeof(out_len) + lable_len + 1;

	/* memcpy instead of a cast store: lable has no u16 alignment guarantee */
	memcpy(lable, &out_len, sizeof(out_len));
	memcpy(lable + sizeof(out_len), g_key_lable[type], lable_len);
	lable[sizeof(out_len) + lable_len] = 0x01;

	ret = hmac_sha256(old_key, HMDFS_KEY_SIZE, lable, lable_size, new_key);
	if (ret < 0)
		hmdfs_err("hmac sha256 error");
	return ret;
}
263