• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright (c) 2018, Mellanox Technologies All rights reserved.
2  *
3  * This software is available to you under a choice of one of two
4  * licenses.  You may choose to be licensed under the terms of the GNU
5  * General Public License (GPL) Version 2, available from the file
6  * COPYING in the main directory of this source tree, or the
7  * OpenIB.org BSD license below:
8  *
9  *     Redistribution and use in source and binary forms, with or
10  *     without modification, are permitted provided that the following
11  *     conditions are met:
12  *
13  *      - Redistributions of source code must retain the above
14  *        copyright notice, this list of conditions and the following
15  *        disclaimer.
16  *
17  *      - Redistributions in binary form must reproduce the above
18  *        copyright notice, this list of conditions and the following
19  *        disclaimer in the documentation and/or other materials
20  *        provided with the distribution.
21  *
22  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
26  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
27  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
28  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29  * SOFTWARE.
30  */
31 
32 #include <crypto/aead.h>
33 #include <linux/highmem.h>
34 #include <linux/module.h>
35 #include <linux/netdevice.h>
36 #include <net/dst.h>
37 #include <net/inet_connection_sock.h>
38 #include <net/tcp.h>
39 #include <net/tls.h>
40 
41 #include "trace.h"
42 
/* device_offload_lock is used to synchronize tls_dev_add
 * against NETDEV_DOWN notifications. It is also taken for read around
 * the driver resync call in tls_device_resync_tx().
 */
static DECLARE_RWSEM(device_offload_lock);

/* Workqueue on which the deferred TX teardown work (destruct_work, handled
 * by tls_device_tx_del_task()) is queued by tls_device_queue_ctx_destruction().
 */
static struct workqueue_struct *destruct_wq __read_mostly;

/* Lists of offloaded TLS contexts, both manipulated under tls_device_lock.
 * NOTE(review): tls_device_down_list presumably holds contexts whose netdev
 * went down (filled outside this view) - confirm against tls_device_down().
 */
static LIST_HEAD(tls_device_list);
static LIST_HEAD(tls_device_down_list);
static DEFINE_SPINLOCK(tls_device_lock);

/* Shared page used by tls_device_record_close() as the tag placeholder when
 * the socket page frag cannot be refilled.
 */
static struct page *dummy_page;
55 
tls_device_free_ctx(struct tls_context * ctx)56 static void tls_device_free_ctx(struct tls_context *ctx)
57 {
58 	if (ctx->tx_conf == TLS_HW) {
59 		kfree(tls_offload_ctx_tx(ctx));
60 		kfree(ctx->tx.rec_seq);
61 		kfree(ctx->tx.iv);
62 	}
63 
64 	if (ctx->rx_conf == TLS_HW)
65 		kfree(tls_offload_ctx_rx(ctx));
66 
67 	tls_ctx_free(NULL, ctx);
68 }
69 
tls_device_tx_del_task(struct work_struct * work)70 static void tls_device_tx_del_task(struct work_struct *work)
71 {
72 	struct tls_offload_context_tx *offload_ctx =
73 		container_of(work, struct tls_offload_context_tx, destruct_work);
74 	struct tls_context *ctx = offload_ctx->ctx;
75 	struct net_device *netdev = ctx->netdev;
76 
77 	netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
78 	dev_put(netdev);
79 	ctx->netdev = NULL;
80 	tls_device_free_ctx(ctx);
81 }
82 
tls_device_queue_ctx_destruction(struct tls_context * ctx)83 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
84 {
85 	unsigned long flags;
86 	bool async_cleanup;
87 
88 	spin_lock_irqsave(&tls_device_lock, flags);
89 	if (unlikely(!refcount_dec_and_test(&ctx->refcount))) {
90 		spin_unlock_irqrestore(&tls_device_lock, flags);
91 		return;
92 	}
93 
94 	list_del(&ctx->list); /* Remove from tls_device_list / tls_device_down_list */
95 	async_cleanup = ctx->netdev && ctx->tx_conf == TLS_HW;
96 	if (async_cleanup) {
97 		struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx);
98 
99 		/* queue_work inside the spinlock
100 		 * to make sure tls_device_down waits for that work.
101 		 */
102 		queue_work(destruct_wq, &offload_ctx->destruct_work);
103 	}
104 	spin_unlock_irqrestore(&tls_device_lock, flags);
105 
106 	if (!async_cleanup)
107 		tls_device_free_ctx(ctx);
108 }
109 
110 /* We assume that the socket is already connected */
get_netdev_for_sock(struct sock * sk)111 static struct net_device *get_netdev_for_sock(struct sock *sk)
112 {
113 	struct dst_entry *dst = sk_dst_get(sk);
114 	struct net_device *netdev = NULL;
115 
116 	if (likely(dst)) {
117 		netdev = netdev_sk_get_lowest_dev(dst->dev, sk);
118 		dev_hold(netdev);
119 	}
120 
121 	dst_release(dst);
122 
123 	return netdev;
124 }
125 
destroy_record(struct tls_record_info * record)126 static void destroy_record(struct tls_record_info *record)
127 {
128 	int i;
129 
130 	for (i = 0; i < record->num_frags; i++)
131 		__skb_frag_unref(&record->frags[i], false);
132 	kfree(record);
133 }
134 
delete_all_records(struct tls_offload_context_tx * offload_ctx)135 static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
136 {
137 	struct tls_record_info *info, *temp;
138 
139 	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
140 		list_del(&info->list);
141 		destroy_record(info);
142 	}
143 
144 	offload_ctx->retransmit_hint = NULL;
145 }
146 
/* Free every TX record whose end_seq is covered by @acked_seq and advance
 * unacked_record_sn by the number of records freed. Runs under the offload
 * context's spinlock.
 */
static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_tx *tx;
	struct tls_record_info *rec, *tmp;
	unsigned long irqflags;
	u64 freed = 0;

	if (!tls_ctx)
		return;

	tx = tls_offload_ctx_tx(tls_ctx);

	spin_lock_irqsave(&tx->lock, irqflags);

	/* Forget the hint if it refers to a record that is now fully acked. */
	if (tx->retransmit_hint &&
	    !before(acked_seq, tx->retransmit_hint->end_seq))
		tx->retransmit_hint = NULL;

	/* Records are kept in sequence order: stop at the first unacked one. */
	list_for_each_entry_safe(rec, tmp, &tx->records_list, list) {
		if (before(acked_seq, rec->end_seq))
			break;
		list_del(&rec->list);
		destroy_record(rec);
		freed++;
	}

	tx->unacked_record_sn += freed;
	spin_unlock_irqrestore(&tx->lock, irqflags);
}
177 
178 /* At this point, there should be no references on this
179  * socket and no in-flight SKBs associated with this
180  * socket, so it is safe to free all the resources.
181  */
tls_device_sk_destruct(struct sock * sk)182 void tls_device_sk_destruct(struct sock *sk)
183 {
184 	struct tls_context *tls_ctx = tls_get_ctx(sk);
185 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
186 
187 	tls_ctx->sk_destruct(sk);
188 
189 	if (tls_ctx->tx_conf == TLS_HW) {
190 		if (ctx->open_record)
191 			destroy_record(ctx->open_record);
192 		delete_all_records(ctx);
193 		crypto_free_aead(ctx->aead_send);
194 		clean_acked_data_disable(inet_csk(sk));
195 	}
196 
197 	tls_device_queue_ctx_destruction(tls_ctx);
198 }
199 EXPORT_SYMBOL_GPL(tls_device_sk_destruct);
200 
/* Release any partially transmitted TX record still held by the socket's TLS
 * context (delegates to tls_free_partial_record()).
 */
void tls_device_free_resources_tx(struct sock *sk)
{
	tls_free_partial_record(sk, tls_get_ctx(sk));
}
207 
/* Ask for a TX resync. This sets TLS_TX_SYNC_SCHED so that the next
 * tls_push_record() calls tls_device_resync_tx(). A one-shot kernel warning
 * is not used here: WARN_ON fires on every request made while one is already
 * pending.
 */
void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	bool already_scheduled;

	trace_tls_device_tx_resync_req(sk, got_seq, exp_seq);

	already_scheduled = test_and_set_bit(TLS_TX_SYNC_SCHED,
					     &tls_ctx->flags);
	WARN_ON(already_scheduled);
}
EXPORT_SYMBOL_GPL(tls_offload_tx_resync_request);
216 
/* Push the current TX record sequence number to the device for TCP seq @seq.
 * TLS_TX_SYNC_SCHED is cleared only when the driver reports success (or no
 * netdev is attached). On failure the bit stays set, so the next pushed record
 * retries.
 */
static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
				 u32 seq)
{
	u8 *rcd_sn = tls_ctx->tx.rec_seq;
	struct net_device *dev;
	struct sk_buff *tail;
	int err = 0;

	/* Set EOR on the current tail skb so later data is not appended to it. */
	tail = tcp_write_queue_tail(sk);
	if (tail)
		TCP_SKB_CB(tail)->eor = 1;

	trace_tls_device_tx_resync_send(sk, seq, rcd_sn);

	down_read(&device_offload_lock);
	dev = tls_ctx->netdev;
	if (dev)
		err = dev->tlsdev_ops->tls_dev_resync(dev, sk, seq, rcd_sn,
						      TLS_OFFLOAD_CTX_DIR_TX);
	up_read(&device_offload_lock);

	if (!err)
		clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags);
}
244 
tls_append_frag(struct tls_record_info * record,struct page_frag * pfrag,int size)245 static void tls_append_frag(struct tls_record_info *record,
246 			    struct page_frag *pfrag,
247 			    int size)
248 {
249 	skb_frag_t *frag;
250 
251 	frag = &record->frags[record->num_frags - 1];
252 	if (skb_frag_page(frag) == pfrag->page &&
253 	    skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
254 		skb_frag_size_add(frag, size);
255 	} else {
256 		++frag;
257 		__skb_frag_set_page(frag, pfrag->page);
258 		skb_frag_off_set(frag, pfrag->offset);
259 		skb_frag_size_set(frag, size);
260 		++record->num_frags;
261 		get_page(pfrag->page);
262 	}
263 
264 	pfrag->offset += size;
265 	record->len += size;
266 }
267 
tls_push_record(struct sock * sk,struct tls_context * ctx,struct tls_offload_context_tx * offload_ctx,struct tls_record_info * record,int flags)268 static int tls_push_record(struct sock *sk,
269 			   struct tls_context *ctx,
270 			   struct tls_offload_context_tx *offload_ctx,
271 			   struct tls_record_info *record,
272 			   int flags)
273 {
274 	struct tls_prot_info *prot = &ctx->prot_info;
275 	struct tcp_sock *tp = tcp_sk(sk);
276 	skb_frag_t *frag;
277 	int i;
278 
279 	record->end_seq = tp->write_seq + record->len;
280 	list_add_tail_rcu(&record->list, &offload_ctx->records_list);
281 	offload_ctx->open_record = NULL;
282 
283 	if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
284 		tls_device_resync_tx(sk, ctx, tp->write_seq);
285 
286 	tls_advance_record_sn(sk, prot, &ctx->tx);
287 
288 	for (i = 0; i < record->num_frags; i++) {
289 		frag = &record->frags[i];
290 		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
291 		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
292 			    skb_frag_size(frag), skb_frag_off(frag));
293 		sk_mem_charge(sk, skb_frag_size(frag));
294 		get_page(skb_frag_page(frag));
295 	}
296 	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
297 
298 	/* all ready, send */
299 	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
300 }
301 
tls_device_record_close(struct sock * sk,struct tls_context * ctx,struct tls_record_info * record,struct page_frag * pfrag,unsigned char record_type)302 static void tls_device_record_close(struct sock *sk,
303 				    struct tls_context *ctx,
304 				    struct tls_record_info *record,
305 				    struct page_frag *pfrag,
306 				    unsigned char record_type)
307 {
308 	struct tls_prot_info *prot = &ctx->prot_info;
309 	struct page_frag dummy_tag_frag;
310 
311 	/* append tag
312 	 * device will fill in the tag, we just need to append a placeholder
313 	 * use socket memory to improve coalescing (re-using a single buffer
314 	 * increases frag count)
315 	 * if we can't allocate memory now use the dummy page
316 	 */
317 	if (unlikely(pfrag->size - pfrag->offset < prot->tag_size) &&
318 	    !skb_page_frag_refill(prot->tag_size, pfrag, sk->sk_allocation)) {
319 		dummy_tag_frag.page = dummy_page;
320 		dummy_tag_frag.offset = 0;
321 		pfrag = &dummy_tag_frag;
322 	}
323 	tls_append_frag(record, pfrag, prot->tag_size);
324 
325 	/* fill prepend */
326 	tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
327 			 record->len - prot->overhead_size,
328 			 record_type);
329 }
330 
/* Allocate a new open record whose first fragment reserves @prepend_size bytes
 * of the page frag for the TLS header. The record takes a page reference, and
 * the record is installed as offload_ctx->open_record.
 * Returns 0 on success or -ENOMEM.
 */
static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
				 struct page_frag *pfrag,
				 size_t prepend_size)
{
	struct tls_record_info *record = kmalloc(sizeof(*record), GFP_KERNEL);
	skb_frag_t *head;

	if (!record)
		return -ENOMEM;

	head = record->frags;
	__skb_frag_set_page(head, pfrag->page);
	skb_frag_off_set(head, pfrag->offset);
	skb_frag_size_set(head, prepend_size);
	get_page(pfrag->page);
	pfrag->offset += prepend_size;

	record->num_frags = 1;
	record->len = prepend_size;
	offload_ctx->open_record = record;

	return 0;
}
355 
/* Make sure there is an open record and free space in @pfrag.
 *
 * With no open record, room for the header is obtained first. On failure the
 * socket enters memory pressure and its sndbuf is moderated. A new record is
 * then created. If the page frag still has room afterwards, we are done;
 * otherwise it is topped up with sk_page_frag_refill().
 * Returns 0 or -ENOMEM.
 */
static int tls_do_allocation(struct sock *sk,
			     struct tls_offload_context_tx *offload_ctx,
			     struct page_frag *pfrag,
			     size_t prepend_size)
{
	int err;

	if (offload_ctx->open_record)
		return sk_page_frag_refill(sk, pfrag) ? 0 : -ENOMEM;

	if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
					   sk->sk_allocation))) {
		READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
		sk_stream_moderate_sndbuf(sk);
		return -ENOMEM;
	}

	err = tls_create_new_record(offload_ctx, pfrag, prepend_size);
	if (err)
		return err;

	if (pfrag->offset < pfrag->size)
		return 0;

	return sk_page_frag_refill(sk, pfrag) ? 0 : -ENOMEM;
}
384 
/* Copy @bytes from iterator @i to @addr, using a non-temporal copy for the
 * cache-line-aligned middle part.
 *
 * The unaligned head (up to the next SMP_CACHE_BYTES boundary) and the
 * sub-line tail use copy_from_iter(). Returns 0, or -EFAULT on a short copy.
 */
static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
{
	size_t head = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1);
	char *dst = addr;
	size_t bulk;

	if (head) {
		head = min(head, bytes);
		if (copy_from_iter(dst, head, i) != head)
			return -EFAULT;
		dst += head;
		bytes -= head;
	}

	bulk = round_down(bytes, SMP_CACHE_BYTES);
	if (copy_from_iter_nocache(dst, bulk, i) != bulk)
		return -EFAULT;
	dst += bulk;
	bytes -= bulk;

	if (bytes && copy_from_iter(dst, bytes, i) != bytes)
		return -EFAULT;

	return 0;
}
409 
/* Copy @size bytes from @msg_iter into device-offloaded TLS records of type
 * @record_type and push each completed record to TCP.
 *
 * The record being filled lives in ctx->open_record and can survive across
 * calls. It is closed and pushed when it reaches max_open_record_len, when it
 * runs short of fragment slots, or when the input is exhausted and neither
 * MSG_MORE nor MSG_SENDPAGE_NOTLAST is set.
 *
 * Returns the number of bytes consumed when that is non-zero. Otherwise it
 * returns the last rc value (0, or a negative errno from allocation, copy or
 * push).
 */
static int tls_push_data(struct sock *sk,
			 struct iov_iter *msg_iter,
			 size_t size, int flags,
			 unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
	struct tls_record_info *record;
	int tls_push_record_flags;
	struct page_frag *pfrag;
	size_t orig_size = size;
	u32 max_open_record_len;
	bool more = false;
	bool done = false;
	int copy, rc = 0;
	long timeo;

	if (flags &
	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
		return -EOPNOTSUPP;

	if (unlikely(sk->sk_err))
		return -sk->sk_err;

	flags |= MSG_SENDPAGE_DECRYPTED;
	/* Records closed mid-message are pushed with NOTLAST (presumably so TCP
	 * keeps batching). Only the final record switches to the caller's
	 * flags, at last_record.
	 */
	tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;

	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	/* Finish sending any record left partially transmitted earlier. */
	if (tls_is_partially_sent_record(tls_ctx)) {
		rc = tls_push_partial_record(sk, tls_ctx, flags);
		if (rc < 0)
			return rc;
	}

	pfrag = sk_page_frag(sk);

	/* TLS_HEADER_SIZE is not counted as part of the TLS record, and
	 * we need to leave room for an authentication tag.
	 */
	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
			      prot->prepend_size;
	do {
		rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size);
		if (unlikely(rc)) {
			/* Out of memory: wait. On success, "continue" re-tests
			 * !done (still false) and retries the allocation.
			 */
			rc = sk_stream_wait_memory(sk, &timeo);
			if (!rc)
				continue;

			record = ctx->open_record;
			if (!record)
				break;
handle_error:
			/* Also reached from a failed copy below. A non-data
			 * record is dropped entirely and size is reset, so no
			 * bytes are reported as sent. A data record already
			 * holding payload is closed and pushed via last_record.
			 */
			if (record_type != TLS_RECORD_TYPE_DATA) {
				/* avoid sending partial
				 * record with type !=
				 * application_data
				 */
				size = orig_size;
				destroy_record(record);
				ctx->open_record = NULL;
			} else if (record->len > prot->prepend_size) {
				goto last_record;
			}

			break;
		}

		record = ctx->open_record;
		/* Bounded by remaining input, page frag space and record room. */
		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
		copy = min_t(size_t, copy, (max_open_record_len - record->len));

		if (copy) {
			rc = tls_device_copy_data(page_address(pfrag->page) +
						  pfrag->offset, copy, msg_iter);
			if (rc)
				goto handle_error;
			tls_append_frag(record, pfrag, copy);
		}

		size -= copy;
		if (!size) {
last_record:
			tls_push_record_flags = flags;
			/* More data is coming: leave the record open, and
			 * signal that via pending_open_record_frags.
			 */
			if (flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE)) {
				more = true;
				break;
			}

			done = true;
		}

		/* One fragment slot is kept free for the tag placeholder that
		 * tls_device_record_close() appends.
		 */
		if (done || record->len >= max_open_record_len ||
		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
			tls_device_record_close(sk, tls_ctx, record,
						pfrag, record_type);

			rc = tls_push_record(sk,
					     tls_ctx,
					     ctx,
					     record,
					     tls_push_record_flags);
			if (rc < 0)
				break;
		}
	} while (!done);

	tls_ctx->pending_open_record_frags = more;

	if (orig_size - size > 0)
		rc = orig_size - size;

	return rc;
}
524 
/* sendmsg() for TLS device offload. Any control message may override the
 * record type; the payload is then pushed through tls_push_data(). Runs under
 * both the context's tx_lock and the socket lock.
 */
int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	int rc = 0;

	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);

	if (unlikely(msg->msg_controllen))
		rc = tls_proccess_cmsg(sk, msg, &record_type);

	if (!rc)
		rc = tls_push_data(sk, &msg->msg_iter, size,
				   msg->msg_flags, record_type);

	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return rc;
}
548 
/* sendpage() for TLS device offload. The page is mapped and wrapped in a
 * single-entry kvec iterator, then sent as application data via
 * tls_push_data(). MSG_SENDPAGE_NOTLAST implies MSG_MORE, and MSG_OOB is
 * rejected with -EOPNOTSUPP.
 */
int tls_device_sendpage(struct sock *sk, struct page *page,
			int offset, size_t size, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct iov_iter iter;
	struct kvec vec;
	int rc = -EOPNOTSUPP;

	if (flags & MSG_SENDPAGE_NOTLAST)
		flags |= MSG_MORE;

	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);

	if (!(flags & MSG_OOB)) {
		vec.iov_base = (char *)kmap(page) + offset;
		vec.iov_len = size;
		iov_iter_kvec(&iter, WRITE, &vec, 1, size);
		rc = tls_push_data(sk, &iter, size,
				   flags, TLS_RECORD_TYPE_DATA);
		kunmap(page);
	}

	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return rc;
}
582 
/* Find the TX record that contains TCP sequence number @seq.
 *
 * The search starts at context->retransmit_hint unless @seq precedes that
 * record's start. In that case it starts at the head of records_list, after
 * checking that @seq lies within the list's range; the range check is skipped
 * when the head is the start-marker record.
 *
 * On success the record's TLS sequence number is stored in *@p_record_sn.
 * If the found record ends after the current hint (or there is no hint), it
 * becomes the new hint. Returns NULL if no record covers @seq.
 *
 * NOTE(review): retransmit_hint / hint_record_sn are written here without
 * taking context->lock; presumably callers hold it - confirm.
 */
struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
				       u32 seq, u64 *p_record_sn)
{
	u64 record_sn = context->hint_record_sn;
	struct tls_record_info *info, *last;

	info = context->retransmit_hint;
	if (!info ||
	    before(seq, info->end_seq - info->len)) {
		/* if retransmit_hint is irrelevant start
		 * from the beginning of the list
		 */
		info = list_first_entry_or_null(&context->records_list,
						struct tls_record_info, list);
		if (!info)
			return NULL;
		/* send the start_marker record if seq number is before the
		 * tls offload start marker sequence number. This record is
		 * required to handle TCP packets which are before TLS offload
		 * started.
		 *  And if it's not start marker, look if this seq number
		 * belongs to the list.
		 */
		if (likely(!tls_record_is_start_marker(info))) {
			/* we have the first record, get the last record to see
			 * if this seq number belongs to the list.
			 */
			last = list_last_entry(&context->records_list,
					       struct tls_record_info, list);

			if (!between(seq, tls_record_start_seq(info),
				     last->end_seq))
				return NULL;
		}
		/* The list head corresponds to the oldest unacked record. */
		record_sn = context->unacked_record_sn;
	}

	/* We just need the _rcu for the READ_ONCE() */
	rcu_read_lock();
	/* Walk forward from info, counting record sequence numbers. */
	list_for_each_entry_from_rcu(info, &context->records_list, list) {
		if (before(seq, info->end_seq)) {
			if (!context->retransmit_hint ||
			    after(info->end_seq,
				  context->retransmit_hint->end_seq)) {
				context->hint_record_sn = record_sn;
				context->retransmit_hint = info;
			}
			*p_record_sn = record_sn;
			goto exit_rcu_unlock;
		}
		record_sn++;
	}
	info = NULL;

exit_rcu_unlock:
	rcu_read_unlock();
	return info;
}
EXPORT_SYMBOL(tls_get_record);
642 
tls_device_push_pending_record(struct sock * sk,int flags)643 static int tls_device_push_pending_record(struct sock *sk, int flags)
644 {
645 	struct iov_iter	msg_iter;
646 
647 	iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
648 	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
649 }
650 
tls_device_write_space(struct sock * sk,struct tls_context * ctx)651 void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
652 {
653 	if (tls_is_partially_sent_record(ctx)) {
654 		gfp_t sk_allocation = sk->sk_allocation;
655 
656 		WARN_ON_ONCE(sk->sk_write_pending);
657 
658 		sk->sk_allocation = GFP_ATOMIC;
659 		tls_push_partial_record(sk, ctx,
660 					MSG_DONTWAIT | MSG_NOSIGNAL |
661 					MSG_SENDPAGE_DECRYPTED);
662 		sk->sk_allocation = sk_allocation;
663 	}
664 }
665 
/* Tell the device that the RX record with sequence number @rcd_sn starts at
 * TCP seq @seq. The netdev pointer is read under RCU and may be NULL, in which
 * case the driver is not called. The resync statistic is bumped either way.
 */
static void tls_device_resync_rx(struct tls_context *tls_ctx,
				 struct sock *sk, u32 seq, u8 *rcd_sn)
{
	struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
	struct net_device *dev;

	trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);

	rcu_read_lock();
	dev = READ_ONCE(tls_ctx->netdev);
	if (dev)
		dev->tlsdev_ops->tls_dev_resync(dev, sk, seq, rcd_sn,
						TLS_OFFLOAD_CTX_DIR_RX);
	rcu_read_unlock();

	TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICERESYNC);
}
681 
/* Process an asynchronous RX resync request for a record header at TCP seq
 * *@seq.
 *
 * @resync_req packs the requested TCP seq in bits 63..32 and a window length
 * in bits 31..16. It also carries flag bits such as RESYNC_REQ_ASYNC
 * (NOTE(review): presumably in the low 16 bits - confirm in net/tls.h).
 *
 * Async stage (RESYNC_REQ_ASYNC set): headers with seq in [req_seq, req_end]
 * are logged, up to TLS_DEVICE_RESYNC_ASYNC_LOGMAX. rcd_delta counts every
 * header at or after req_seq. Nothing is resynced and false is returned.
 *
 * Sync stage: the final req_seq is matched against the log, then against
 * *@seq itself. On a match the request is cleared with a cmpxchg to 0,
 * *@seq / *@rcd_delta are set (rcd_delta = records since the matched header),
 * and true is returned. The log and counter are reset in either outcome.
 */
static bool
tls_device_rx_resync_async(struct tls_offload_resync_async *resync_async,
			   s64 resync_req, u32 *seq, u16 *rcd_delta)
{
	u32 is_async = resync_req & RESYNC_REQ_ASYNC;
	u32 req_seq = resync_req >> 32;
	u32 req_end = req_seq + ((resync_req >> 16) & 0xffff);
	u16 i;

	*rcd_delta = 0;

	if (is_async) {
		/* shouldn't get to wraparound:
		 * too long in async stage, something bad happened
		 */
		if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX))
			return false;

		/* asynchronous stage: log all headers seq such that
		 * req_seq <= seq <= end_seq, and wait for real resync request
		 */
		if (before(*seq, req_seq))
			return false;
		if (!after(*seq, req_end) &&
		    resync_async->loglen < TLS_DEVICE_RESYNC_ASYNC_LOGMAX)
			resync_async->log[resync_async->loglen++] = *seq;

		resync_async->rcd_delta++;

		return false;
	}

	/* synchronous stage: check against the logged entries and
	 * proceed to check the next entries if no match was found
	 */
	/* On failure, atomic64_try_cmpxchg() refreshes resync_req with the
	 * current value of resync_async->req.
	 */
	for (i = 0; i < resync_async->loglen; i++)
		if (req_seq == resync_async->log[i] &&
		    atomic64_try_cmpxchg(&resync_async->req, &resync_req, 0)) {
			*rcd_delta = resync_async->rcd_delta - i;
			*seq = req_seq;
			resync_async->loglen = 0;
			resync_async->rcd_delta = 0;
			return true;
		}

	resync_async->loglen = 0;
	resync_async->rcd_delta = 0;

	if (req_seq == *seq &&
	    atomic64_try_cmpxchg(&resync_async->req,
				 &resync_req, 0))
		return true;

	return false;
}
737 
/* Called for a new RX record of length @rcd_len whose header starts at TCP
 * seq @seq. If a resync is due under the context's resync_type, it issues one
 * via tls_device_resync_rx().
 *
 * Nothing is done unless RX is offloaded (TLS_HW) and the offload is not
 * degraded. The record sequence number sent to the device starts as a copy of
 * tls_ctx->rx.rec_seq and is adjusted per strategy:
 *  - DRIVER_REQ: resync only when the pending driver request matches this
 *    record and can be cleared atomically.
 *  - CORE_NEXT_HINT: resync at the *next* record (seq + rcd_len, rec_seq + 1),
 *    once that record's header is not yet in the socket queue.
 *  - DRIVER_REQ_ASYNC: delegate to tls_device_rx_resync_async(), then step
 *    rec_seq back by the returned record delta.
 */
void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_rx *rx_ctx;
	u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
	u32 sock_data, is_req_pending;
	struct tls_prot_info *prot;
	s64 resync_req;
	u16 rcd_delta;
	u32 req_seq;

	if (tls_ctx->rx_conf != TLS_HW)
		return;
	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags)))
		return;

	prot = &tls_ctx->prot_info;
	rx_ctx = tls_offload_ctx_rx(tls_ctx);
	memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);

	switch (rx_ctx->resync_type) {
	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ:
		resync_req = atomic64_read(&rx_ctx->resync_req);
		req_seq = resync_req >> 32;
		/* NOTE(review): the driver's request seq apparently refers to
		 * the last byte of the TLS header - confirm with the driver API.
		 */
		seq += TLS_HEADER_SIZE - 1;
		/* Deliberate truncation: only the low 32 bits (flags) tell
		 * whether a request is pending.
		 */
		is_req_pending = resync_req;

		if (likely(!is_req_pending) || req_seq != seq ||
		    !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
			return;
		break;
	case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT:
		if (likely(!rx_ctx->resync_nh_do_now))
			return;

		/* head of next rec is already in, note that the sock_inq will
		 * include the currently parsed message when called from parser
		 */
		sock_data = tcp_inq(sk);
		if (sock_data > rcd_len) {
			trace_tls_device_rx_resync_nh_delay(sk, sock_data,
							    rcd_len);
			return;
		}

		rx_ctx->resync_nh_do_now = 0;
		seq += rcd_len;
		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
		break;
	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC:
		resync_req = atomic64_read(&rx_ctx->resync_async->req);
		/* Same low-32-bit pending test as DRIVER_REQ above. */
		is_req_pending = resync_req;
		if (likely(!is_req_pending))
			return;

		if (!tls_device_rx_resync_async(rx_ctx->resync_async,
						resync_req, &seq, &rcd_delta))
			return;
		tls_bigint_subtract(rcd_sn, rcd_delta);
		break;
	}

	tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
}
802 
tls_device_core_ctrl_rx_resync(struct tls_context * tls_ctx,struct tls_offload_context_rx * ctx,struct sock * sk,struct sk_buff * skb)803 static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
804 					   struct tls_offload_context_rx *ctx,
805 					   struct sock *sk, struct sk_buff *skb)
806 {
807 	struct strp_msg *rxm;
808 
809 	/* device will request resyncs by itself based on stream scan */
810 	if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT)
811 		return;
812 	/* already scheduled */
813 	if (ctx->resync_nh_do_now)
814 		return;
815 	/* seen decrypted fragments since last fully-failed record */
816 	if (ctx->resync_nh_reset) {
817 		ctx->resync_nh_reset = 0;
818 		ctx->resync_nh.decrypted_failed = 1;
819 		ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL;
820 		return;
821 	}
822 
823 	if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt)
824 		return;
825 
826 	/* doing resync, bump the next target in case it fails */
827 	if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL)
828 		ctx->resync_nh.decrypted_tgt *= 2;
829 	else
830 		ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL;
831 
832 	rxm = strp_msg(skb);
833 
834 	/* head of next rec is already in, parser will sync for us */
835 	if (tcp_inq(sk) > rxm->full_len) {
836 		trace_tls_device_rx_resync_nh_schedule(sk);
837 		ctx->resync_nh_do_now = 1;
838 	} else {
839 		struct tls_prot_info *prot = &tls_ctx->prot_info;
840 		u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
841 
842 		memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
843 		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
844 
845 		tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq,
846 				     rcd_sn);
847 	}
848 }
849 
/* Restore a partially device-decrypted record to ciphertext so that it can be
 * decrypted in software.
 *
 * The record is run through decrypt_skb() into a bounce buffer. Authentication
 * is expected to fail with -EBADMSG, because the input is a mix of plaintext
 * and ciphertext. The output bytes are then written back only into the skb
 * parts flagged ->decrypted. This relies on AES-GCM being a CTR-based stream
 * mode: "decrypting" device plaintext yields the original ciphertext. The AES-
 * GCM-128 sizes are hard-coded.
 *
 * Returns 0 on success or a negative errno. Any decrypt_skb() result other
 * than -EBADMSG is returned as-is, without modifying the skb.
 */
static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
	struct sk_buff *skb_iter, *unused;
	struct scatterlist sg[1];
	char *orig_buf, *buf;

	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
	if (!orig_buf)
		return -ENOMEM;
	buf = orig_buf;

	/* Make the skb data writable before storing into it below. */
	nsg = skb_cow_data(skb, 0, &unused);
	if (unlikely(nsg < 0)) {
		err = nsg;
		goto free_buf;
	}

	sg_init_table(sg, 1);
	sg_set_buf(&sg[0], buf,
		   rxm->full_len + TLS_HEADER_SIZE +
		   TLS_CIPHER_AES_GCM_128_IV_SIZE);
	/* Seed buf with the record's header and explicit IV.
	 * NOTE(review): assumes decrypt_skb() places its output so that buf
	 * mirrors the record from rxm->offset onward - confirm in tls_sw.c.
	 */
	err = skb_copy_bits(skb, offset, buf,
			    TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
	if (err)
		goto free_buf;

	/* We are interested only in the decrypted data not the auth */
	err = decrypt_skb(sk, skb, sg);
	if (err != -EBADMSG)
		goto free_buf;
	else
		err = 0;

	/* Write back everything up to, but excluding, the tag. */
	data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;

	/* Part of the record in the head skb's linear + page data. */
	if (skb_pagelen(skb) > offset) {
		copy = min_t(int, skb_pagelen(skb) - offset, data_len);

		if (skb->decrypted) {
			err = skb_store_bits(skb, offset, buf, copy);
			if (err)
				goto free_buf;
		}

		offset += copy;
		buf += copy;
	}

	/* Remaining part in the frag_list skbs; pos is each one's start. */
	pos = skb_pagelen(skb);
	skb_walk_frags(skb, skb_iter) {
		int frag_pos;

		/* Practically all frags must belong to msg if reencrypt
		 * is needed with current strparser and coalescing logic,
		 * but strparser may "get optimized", so let's be safe.
		 */
		if (pos + skb_iter->len <= offset)
			goto done_with_frag;
		if (pos >= data_len + rxm->offset)
			break;

		frag_pos = offset - pos;
		copy = min_t(int, skb_iter->len - frag_pos,
			     data_len + rxm->offset - offset);

		if (skb_iter->decrypted) {
			err = skb_store_bits(skb_iter, frag_pos, buf, copy);
			if (err)
				goto free_buf;
		}

		offset += copy;
		buf += copy;
done_with_frag:
		pos += skb_iter->len;
	}

free_buf:
	kfree(orig_buf);
	return err;
}
934 
tls_device_decrypted(struct sock * sk,struct tls_context * tls_ctx,struct sk_buff * skb,struct strp_msg * rxm)935 int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
936 			 struct sk_buff *skb, struct strp_msg *rxm)
937 {
938 	struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
939 	int is_decrypted = skb->decrypted;
940 	int is_encrypted = !is_decrypted;
941 	struct sk_buff *skb_iter;
942 
943 	/* Check if all the data is decrypted already */
944 	skb_walk_frags(skb, skb_iter) {
945 		is_decrypted &= skb_iter->decrypted;
946 		is_encrypted &= !skb_iter->decrypted;
947 	}
948 
949 	trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
950 				   tls_ctx->rx.rec_seq, rxm->full_len,
951 				   is_encrypted, is_decrypted);
952 
953 	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
954 		if (likely(is_encrypted || is_decrypted))
955 			return is_decrypted;
956 
957 		/* After tls_device_down disables the offload, the next SKB will
958 		 * likely have initial fragments decrypted, and final ones not
959 		 * decrypted. We need to reencrypt that single SKB.
960 		 */
961 		return tls_device_reencrypt(sk, skb);
962 	}
963 
964 	/* Return immediately if the record is either entirely plaintext or
965 	 * entirely ciphertext. Otherwise handle reencrypt partially decrypted
966 	 * record.
967 	 */
968 	if (is_decrypted) {
969 		ctx->resync_nh_reset = 1;
970 		return is_decrypted;
971 	}
972 	if (is_encrypted) {
973 		tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
974 		return 0;
975 	}
976 
977 	ctx->resync_nh_reset = 1;
978 	return tls_device_reencrypt(sk, skb);
979 }
980 
tls_device_attach(struct tls_context * ctx,struct sock * sk,struct net_device * netdev)981 static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
982 			      struct net_device *netdev)
983 {
984 	if (sk->sk_destruct != tls_device_sk_destruct) {
985 		refcount_set(&ctx->refcount, 1);
986 		dev_hold(netdev);
987 		ctx->netdev = netdev;
988 		spin_lock_irq(&tls_device_lock);
989 		list_add_tail(&ctx->list, &tls_device_list);
990 		spin_unlock_irq(&tls_device_lock);
991 
992 		ctx->sk_destruct = sk->sk_destruct;
993 		smp_store_release(&sk->sk_destruct, tls_device_sk_destruct);
994 	}
995 }
996 
tls_set_device_offload(struct sock * sk,struct tls_context * ctx)997 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
998 {
999 	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
1000 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1001 	struct tls_prot_info *prot = &tls_ctx->prot_info;
1002 	struct tls_record_info *start_marker_record;
1003 	struct tls_offload_context_tx *offload_ctx;
1004 	struct tls_crypto_info *crypto_info;
1005 	struct net_device *netdev;
1006 	char *iv, *rec_seq;
1007 	struct sk_buff *skb;
1008 	__be64 rcd_sn;
1009 	int rc;
1010 
1011 	if (!ctx)
1012 		return -EINVAL;
1013 
1014 	if (ctx->priv_ctx_tx)
1015 		return -EEXIST;
1016 
1017 	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
1018 	if (!start_marker_record)
1019 		return -ENOMEM;
1020 
1021 	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
1022 	if (!offload_ctx) {
1023 		rc = -ENOMEM;
1024 		goto free_marker_record;
1025 	}
1026 
1027 	crypto_info = &ctx->crypto_send.info;
1028 	if (crypto_info->version != TLS_1_2_VERSION) {
1029 		rc = -EOPNOTSUPP;
1030 		goto free_offload_ctx;
1031 	}
1032 
1033 	switch (crypto_info->cipher_type) {
1034 	case TLS_CIPHER_AES_GCM_128:
1035 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1036 		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1037 		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1038 		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1039 		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1040 		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
1041 		rec_seq =
1042 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1043 		break;
1044 	default:
1045 		rc = -EINVAL;
1046 		goto free_offload_ctx;
1047 	}
1048 
1049 	/* Sanity-check the rec_seq_size for stack allocations */
1050 	if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
1051 		rc = -EINVAL;
1052 		goto free_offload_ctx;
1053 	}
1054 
1055 	prot->version = crypto_info->version;
1056 	prot->cipher_type = crypto_info->cipher_type;
1057 	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
1058 	prot->tag_size = tag_size;
1059 	prot->overhead_size = prot->prepend_size + prot->tag_size;
1060 	prot->iv_size = iv_size;
1061 	prot->salt_size = salt_size;
1062 	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1063 			     GFP_KERNEL);
1064 	if (!ctx->tx.iv) {
1065 		rc = -ENOMEM;
1066 		goto free_offload_ctx;
1067 	}
1068 
1069 	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1070 
1071 	prot->rec_seq_size = rec_seq_size;
1072 	ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
1073 	if (!ctx->tx.rec_seq) {
1074 		rc = -ENOMEM;
1075 		goto free_iv;
1076 	}
1077 
1078 	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
1079 	if (rc)
1080 		goto free_rec_seq;
1081 
1082 	/* start at rec_seq - 1 to account for the start marker record */
1083 	memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
1084 	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
1085 
1086 	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
1087 	start_marker_record->len = 0;
1088 	start_marker_record->num_frags = 0;
1089 
1090 	INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task);
1091 	offload_ctx->ctx = ctx;
1092 
1093 	INIT_LIST_HEAD(&offload_ctx->records_list);
1094 	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
1095 	spin_lock_init(&offload_ctx->lock);
1096 	sg_init_table(offload_ctx->sg_tx_data,
1097 		      ARRAY_SIZE(offload_ctx->sg_tx_data));
1098 
1099 	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
1100 	ctx->push_pending_record = tls_device_push_pending_record;
1101 
1102 	/* TLS offload is greatly simplified if we don't send
1103 	 * SKBs where only part of the payload needs to be encrypted.
1104 	 * So mark the last skb in the write queue as end of record.
1105 	 */
1106 	skb = tcp_write_queue_tail(sk);
1107 	if (skb)
1108 		TCP_SKB_CB(skb)->eor = 1;
1109 
1110 	netdev = get_netdev_for_sock(sk);
1111 	if (!netdev) {
1112 		pr_err_ratelimited("%s: netdev not found\n", __func__);
1113 		rc = -EINVAL;
1114 		goto disable_cad;
1115 	}
1116 
1117 	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
1118 		rc = -EOPNOTSUPP;
1119 		goto release_netdev;
1120 	}
1121 
1122 	/* Avoid offloading if the device is down
1123 	 * We don't want to offload new flows after
1124 	 * the NETDEV_DOWN event
1125 	 *
1126 	 * device_offload_lock is taken in tls_devices's NETDEV_DOWN
1127 	 * handler thus protecting from the device going down before
1128 	 * ctx was added to tls_device_list.
1129 	 */
1130 	down_read(&device_offload_lock);
1131 	if (!(netdev->flags & IFF_UP)) {
1132 		rc = -EINVAL;
1133 		goto release_lock;
1134 	}
1135 
1136 	ctx->priv_ctx_tx = offload_ctx;
1137 	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
1138 					     &ctx->crypto_send.info,
1139 					     tcp_sk(sk)->write_seq);
1140 	trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_TX,
1141 				     tcp_sk(sk)->write_seq, rec_seq, rc);
1142 	if (rc)
1143 		goto release_lock;
1144 
1145 	tls_device_attach(ctx, sk, netdev);
1146 	up_read(&device_offload_lock);
1147 
1148 	/* following this assignment tls_is_sk_tx_device_offloaded
1149 	 * will return true and the context might be accessed
1150 	 * by the netdev's xmit function.
1151 	 */
1152 	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
1153 	dev_put(netdev);
1154 
1155 	return 0;
1156 
1157 release_lock:
1158 	up_read(&device_offload_lock);
1159 release_netdev:
1160 	dev_put(netdev);
1161 disable_cad:
1162 	clean_acked_data_disable(inet_csk(sk));
1163 	crypto_free_aead(offload_ctx->aead_send);
1164 free_rec_seq:
1165 	kfree(ctx->tx.rec_seq);
1166 free_iv:
1167 	kfree(ctx->tx.iv);
1168 free_offload_ctx:
1169 	kfree(offload_ctx);
1170 	ctx->priv_ctx_tx = NULL;
1171 free_marker_record:
1172 	kfree(start_marker_record);
1173 	return rc;
1174 }
1175 
tls_set_device_offload_rx(struct sock * sk,struct tls_context * ctx)1176 int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
1177 {
1178 	struct tls12_crypto_info_aes_gcm_128 *info;
1179 	struct tls_offload_context_rx *context;
1180 	struct net_device *netdev;
1181 	int rc = 0;
1182 
1183 	if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
1184 		return -EOPNOTSUPP;
1185 
1186 	netdev = get_netdev_for_sock(sk);
1187 	if (!netdev) {
1188 		pr_err_ratelimited("%s: netdev not found\n", __func__);
1189 		return -EINVAL;
1190 	}
1191 
1192 	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
1193 		rc = -EOPNOTSUPP;
1194 		goto release_netdev;
1195 	}
1196 
1197 	/* Avoid offloading if the device is down
1198 	 * We don't want to offload new flows after
1199 	 * the NETDEV_DOWN event
1200 	 *
1201 	 * device_offload_lock is taken in tls_devices's NETDEV_DOWN
1202 	 * handler thus protecting from the device going down before
1203 	 * ctx was added to tls_device_list.
1204 	 */
1205 	down_read(&device_offload_lock);
1206 	if (!(netdev->flags & IFF_UP)) {
1207 		rc = -EINVAL;
1208 		goto release_lock;
1209 	}
1210 
1211 	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
1212 	if (!context) {
1213 		rc = -ENOMEM;
1214 		goto release_lock;
1215 	}
1216 	context->resync_nh_reset = 1;
1217 
1218 	ctx->priv_ctx_rx = context;
1219 	rc = tls_set_sw_offload(sk, ctx, 0);
1220 	if (rc)
1221 		goto release_ctx;
1222 
1223 	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
1224 					     &ctx->crypto_recv.info,
1225 					     tcp_sk(sk)->copied_seq);
1226 	info = (void *)&ctx->crypto_recv.info;
1227 	trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_RX,
1228 				     tcp_sk(sk)->copied_seq, info->rec_seq, rc);
1229 	if (rc)
1230 		goto free_sw_resources;
1231 
1232 	tls_device_attach(ctx, sk, netdev);
1233 	up_read(&device_offload_lock);
1234 
1235 	dev_put(netdev);
1236 
1237 	return 0;
1238 
1239 free_sw_resources:
1240 	up_read(&device_offload_lock);
1241 	tls_sw_free_resources_rx(sk);
1242 	down_read(&device_offload_lock);
1243 release_ctx:
1244 	ctx->priv_ctx_rx = NULL;
1245 release_lock:
1246 	up_read(&device_offload_lock);
1247 release_netdev:
1248 	dev_put(netdev);
1249 	return rc;
1250 }
1251 
tls_device_offload_cleanup_rx(struct sock * sk)1252 void tls_device_offload_cleanup_rx(struct sock *sk)
1253 {
1254 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1255 	struct net_device *netdev;
1256 
1257 	down_read(&device_offload_lock);
1258 	netdev = tls_ctx->netdev;
1259 	if (!netdev)
1260 		goto out;
1261 
1262 	netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
1263 					TLS_OFFLOAD_CTX_DIR_RX);
1264 
1265 	if (tls_ctx->tx_conf != TLS_HW) {
1266 		dev_put(netdev);
1267 		tls_ctx->netdev = NULL;
1268 	} else {
1269 		set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
1270 	}
1271 out:
1272 	up_read(&device_offload_lock);
1273 	tls_sw_release_resources_rx(sk);
1274 }
1275 
tls_device_down(struct net_device * netdev)1276 static int tls_device_down(struct net_device *netdev)
1277 {
1278 	struct tls_context *ctx, *tmp;
1279 	unsigned long flags;
1280 	LIST_HEAD(list);
1281 
1282 	/* Request a write lock to block new offload attempts */
1283 	down_write(&device_offload_lock);
1284 
1285 	spin_lock_irqsave(&tls_device_lock, flags);
1286 	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
1287 		if (ctx->netdev != netdev ||
1288 		    !refcount_inc_not_zero(&ctx->refcount))
1289 			continue;
1290 
1291 		list_move(&ctx->list, &list);
1292 	}
1293 	spin_unlock_irqrestore(&tls_device_lock, flags);
1294 
1295 	list_for_each_entry_safe(ctx, tmp, &list, list)	{
1296 		/* Stop offloaded TX and switch to the fallback.
1297 		 * tls_is_sk_tx_device_offloaded will return false.
1298 		 */
1299 		WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);
1300 
1301 		/* Stop the RX and TX resync.
1302 		 * tls_dev_resync must not be called after tls_dev_del.
1303 		 */
1304 		WRITE_ONCE(ctx->netdev, NULL);
1305 
1306 		/* Start skipping the RX resync logic completely. */
1307 		set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);
1308 
1309 		/* Sync with inflight packets. After this point:
1310 		 * TX: no non-encrypted packets will be passed to the driver.
1311 		 * RX: resync requests from the driver will be ignored.
1312 		 */
1313 		synchronize_net();
1314 
1315 		/* Release the offload context on the driver side. */
1316 		if (ctx->tx_conf == TLS_HW)
1317 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
1318 							TLS_OFFLOAD_CTX_DIR_TX);
1319 		if (ctx->rx_conf == TLS_HW &&
1320 		    !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
1321 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
1322 							TLS_OFFLOAD_CTX_DIR_RX);
1323 
1324 		dev_put(netdev);
1325 
1326 		/* Move the context to a separate list for two reasons:
1327 		 * 1. When the context is deallocated, list_del is called.
1328 		 * 2. It's no longer an offloaded context, so we don't want to
1329 		 *    run offload-specific code on this context.
1330 		 */
1331 		spin_lock_irqsave(&tls_device_lock, flags);
1332 		list_move_tail(&ctx->list, &tls_device_down_list);
1333 		spin_unlock_irqrestore(&tls_device_lock, flags);
1334 
1335 		/* Device contexts for RX and TX will be freed in on sk_destruct
1336 		 * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
1337 		 * Now release the ref taken above.
1338 		 */
1339 		if (refcount_dec_and_test(&ctx->refcount)) {
1340 			/* sk_destruct ran after tls_device_down took a ref, and
1341 			 * it returned early. Complete the destruction here.
1342 			 */
1343 			list_del(&ctx->list);
1344 			tls_device_free_ctx(ctx);
1345 		}
1346 	}
1347 
1348 	up_write(&device_offload_lock);
1349 
1350 	flush_workqueue(destruct_wq);
1351 
1352 	return NOTIFY_DONE;
1353 }
1354 
tls_dev_event(struct notifier_block * this,unsigned long event,void * ptr)1355 static int tls_dev_event(struct notifier_block *this, unsigned long event,
1356 			 void *ptr)
1357 {
1358 	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
1359 
1360 	if (!dev->tlsdev_ops &&
1361 	    !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
1362 		return NOTIFY_DONE;
1363 
1364 	switch (event) {
1365 	case NETDEV_REGISTER:
1366 	case NETDEV_FEAT_CHANGE:
1367 		if (netif_is_bond_master(dev))
1368 			return NOTIFY_DONE;
1369 		if ((dev->features & NETIF_F_HW_TLS_RX) &&
1370 		    !dev->tlsdev_ops->tls_dev_resync)
1371 			return NOTIFY_BAD;
1372 
1373 		if  (dev->tlsdev_ops &&
1374 		     dev->tlsdev_ops->tls_dev_add &&
1375 		     dev->tlsdev_ops->tls_dev_del)
1376 			return NOTIFY_DONE;
1377 		else
1378 			return NOTIFY_BAD;
1379 	case NETDEV_DOWN:
1380 		return tls_device_down(dev);
1381 	}
1382 	return NOTIFY_DONE;
1383 }
1384 
1385 static struct notifier_block tls_dev_notifier = {
1386 	.notifier_call	= tls_dev_event,
1387 };
1388 
tls_device_init(void)1389 int __init tls_device_init(void)
1390 {
1391 	int err;
1392 
1393 	dummy_page = alloc_page(GFP_KERNEL);
1394 	if (!dummy_page)
1395 		return -ENOMEM;
1396 
1397 	destruct_wq = alloc_workqueue("ktls_device_destruct", 0, 0);
1398 	if (!destruct_wq) {
1399 		err = -ENOMEM;
1400 		goto err_free_dummy;
1401 	}
1402 
1403 	err = register_netdevice_notifier(&tls_dev_notifier);
1404 	if (err)
1405 		goto err_destroy_wq;
1406 
1407 	return 0;
1408 
1409 err_destroy_wq:
1410 	destroy_workqueue(destruct_wq);
1411 err_free_dummy:
1412 	put_page(dummy_page);
1413 	return err;
1414 }
1415 
tls_device_cleanup(void)1416 void __exit tls_device_cleanup(void)
1417 {
1418 	unregister_netdevice_notifier(&tls_dev_notifier);
1419 	flush_workqueue(destruct_wq);
1420 	destroy_workqueue(destruct_wq);
1421 	clean_acked_data_flush();
1422 	put_page(dummy_page);
1423 }
1424