• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP)
4  *
5  * Copyright (c) 2021 Code Construct
6  * Copyright (c) 2021 Google
7  */
8 
9 #include <linux/if_arp.h>
10 #include <linux/net.h>
11 #include <linux/mctp.h>
12 #include <linux/module.h>
13 #include <linux/socket.h>
14 
15 #include <net/mctp.h>
16 #include <net/mctpdevice.h>
17 #include <net/sock.h>
18 
19 /* socket implementation */
20 
mctp_release(struct socket * sock)21 static int mctp_release(struct socket *sock)
22 {
23 	struct sock *sk = sock->sk;
24 
25 	if (sk) {
26 		sock->sk = NULL;
27 		sk->sk_prot->close(sk, 0);
28 	}
29 
30 	return 0;
31 }
32 
33 /* Generic sockaddr checks, padding checks only so far */
mctp_sockaddr_is_ok(const struct sockaddr_mctp * addr)34 static bool mctp_sockaddr_is_ok(const struct sockaddr_mctp *addr)
35 {
36 	return !addr->__smctp_pad0 && !addr->__smctp_pad1;
37 }
38 
mctp_bind(struct socket * sock,struct sockaddr * addr,int addrlen)39 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
40 {
41 	struct sock *sk = sock->sk;
42 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
43 	struct sockaddr_mctp *smctp;
44 	int rc;
45 
46 	if (addrlen < sizeof(*smctp))
47 		return -EINVAL;
48 
49 	if (addr->sa_family != AF_MCTP)
50 		return -EAFNOSUPPORT;
51 
52 	if (!capable(CAP_NET_BIND_SERVICE))
53 		return -EACCES;
54 
55 	/* it's a valid sockaddr for MCTP, cast and do protocol checks */
56 	smctp = (struct sockaddr_mctp *)addr;
57 
58 	if (!mctp_sockaddr_is_ok(smctp))
59 		return -EINVAL;
60 
61 	lock_sock(sk);
62 
63 	/* TODO: allow rebind */
64 	if (sk_hashed(sk)) {
65 		rc = -EADDRINUSE;
66 		goto out_release;
67 	}
68 	msk->bind_net = smctp->smctp_network;
69 	msk->bind_addr = smctp->smctp_addr.s_addr;
70 	msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
71 
72 	rc = sk->sk_prot->hash(sk);
73 
74 out_release:
75 	release_sock(sk);
76 
77 	return rc;
78 }
79 
mctp_sendmsg(struct socket * sock,struct msghdr * msg,size_t len)80 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
81 {
82 	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
83 	const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
84 	int rc, addrlen = msg->msg_namelen;
85 	struct sock *sk = sock->sk;
86 	struct mctp_skb_cb *cb;
87 	struct mctp_route *rt;
88 	struct sk_buff *skb;
89 
90 	if (addr) {
91 		if (addrlen < sizeof(struct sockaddr_mctp))
92 			return -EINVAL;
93 		if (addr->smctp_family != AF_MCTP)
94 			return -EINVAL;
95 		if (!mctp_sockaddr_is_ok(addr))
96 			return -EINVAL;
97 		if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
98 			return -EINVAL;
99 
100 	} else {
101 		/* TODO: connect()ed sockets */
102 		return -EDESTADDRREQ;
103 	}
104 
105 	if (!capable(CAP_NET_RAW))
106 		return -EACCES;
107 
108 	if (addr->smctp_network == MCTP_NET_ANY)
109 		addr->smctp_network = mctp_default_net(sock_net(sk));
110 
111 	rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
112 			       addr->smctp_addr.s_addr);
113 	if (!rt)
114 		return -EHOSTUNREACH;
115 
116 	skb = sock_alloc_send_skb(sk, hlen + 1 + len,
117 				  msg->msg_flags & MSG_DONTWAIT, &rc);
118 	if (!skb)
119 		return rc;
120 
121 	skb_reserve(skb, hlen);
122 
123 	/* set type as fist byte in payload */
124 	*(u8 *)skb_put(skb, 1) = addr->smctp_type;
125 
126 	rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
127 	if (rc < 0) {
128 		kfree_skb(skb);
129 		return rc;
130 	}
131 
132 	/* set up cb */
133 	cb = __mctp_cb(skb);
134 	cb->net = addr->smctp_network;
135 
136 	rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
137 			       addr->smctp_tag);
138 
139 	return rc ? : len;
140 }
141 
mctp_recvmsg(struct socket * sock,struct msghdr * msg,size_t len,int flags)142 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
143 			int flags)
144 {
145 	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
146 	struct sock *sk = sock->sk;
147 	struct sk_buff *skb;
148 	size_t msglen;
149 	u8 type;
150 	int rc;
151 
152 	if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
153 		return -EOPNOTSUPP;
154 
155 	skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
156 	if (!skb)
157 		return rc;
158 
159 	if (!skb->len) {
160 		rc = 0;
161 		goto out_free;
162 	}
163 
164 	/* extract message type, remove from data */
165 	type = *((u8 *)skb->data);
166 	msglen = skb->len - 1;
167 
168 	if (len < msglen)
169 		msg->msg_flags |= MSG_TRUNC;
170 	else
171 		len = msglen;
172 
173 	rc = skb_copy_datagram_msg(skb, 1, msg, len);
174 	if (rc < 0)
175 		goto out_free;
176 
177 	sock_recv_cmsgs(msg, sk, skb);
178 
179 	if (addr) {
180 		struct mctp_skb_cb *cb = mctp_cb(skb);
181 		/* TODO: expand mctp_skb_cb for header fields? */
182 		struct mctp_hdr *hdr = mctp_hdr(skb);
183 
184 		addr = msg->msg_name;
185 		addr->smctp_family = AF_MCTP;
186 		addr->__smctp_pad0 = 0;
187 		addr->smctp_network = cb->net;
188 		addr->smctp_addr.s_addr = hdr->src;
189 		addr->smctp_type = type;
190 		addr->smctp_tag = hdr->flags_seq_tag &
191 					(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
192 		addr->__smctp_pad1 = 0;
193 		msg->msg_namelen = sizeof(*addr);
194 	}
195 
196 	rc = len;
197 
198 	if (flags & MSG_TRUNC)
199 		rc = msglen;
200 
201 out_free:
202 	skb_free_datagram(sk, skb);
203 	return rc;
204 }
205 
/* proto_ops setsockopt handler: no socket options are implemented, so every
 * request is rejected with -EINVAL.
 */
static int mctp_setsockopt(struct socket *sock, int level, int optname,
			   sockptr_t optval, unsigned int optlen)
{
	return -EINVAL;
}
211 
/* proto_ops getsockopt handler: no socket options are implemented, so every
 * request is rejected with -EINVAL.
 */
static int mctp_getsockopt(struct socket *sock, int level, int optname,
			   char __user *optval, int __user *optlen)
{
	return -EINVAL;
}
217 
218 static const struct proto_ops mctp_dgram_ops = {
219 	.family		= PF_MCTP,
220 	.release	= mctp_release,
221 	.bind		= mctp_bind,
222 	.connect	= sock_no_connect,
223 	.socketpair	= sock_no_socketpair,
224 	.accept		= sock_no_accept,
225 	.getname	= sock_no_getname,
226 	.poll		= datagram_poll,
227 	.ioctl		= sock_no_ioctl,
228 	.gettstamp	= sock_gettstamp,
229 	.listen		= sock_no_listen,
230 	.shutdown	= sock_no_shutdown,
231 	.setsockopt	= mctp_setsockopt,
232 	.getsockopt	= mctp_getsockopt,
233 	.sendmsg	= mctp_sendmsg,
234 	.recvmsg	= mctp_recvmsg,
235 	.mmap		= sock_no_mmap,
236 	.sendpage	= sock_no_sendpage,
237 };
238 
mctp_sk_init(struct sock * sk)239 static int mctp_sk_init(struct sock *sk)
240 {
241 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
242 
243 	INIT_HLIST_HEAD(&msk->keys);
244 	return 0;
245 }
246 
/* struct proto close handler: defer entirely to sk_common_release(); the
 * @timeout argument is not used.
 */
static void mctp_sk_close(struct sock *sk, long timeout)
{
	sk_common_release(sk);
}
251 
mctp_sk_hash(struct sock * sk)252 static int mctp_sk_hash(struct sock *sk)
253 {
254 	struct net *net = sock_net(sk);
255 
256 	mutex_lock(&net->mctp.bind_lock);
257 	sk_add_node_rcu(sk, &net->mctp.binds);
258 	mutex_unlock(&net->mctp.bind_lock);
259 
260 	return 0;
261 }
262 
mctp_sk_unhash(struct sock * sk)263 static void mctp_sk_unhash(struct sock *sk)
264 {
265 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
266 	struct net *net = sock_net(sk);
267 	struct mctp_sk_key *key;
268 	struct hlist_node *tmp;
269 	unsigned long flags;
270 
271 	/* remove from any type-based binds */
272 	mutex_lock(&net->mctp.bind_lock);
273 	sk_del_node_init_rcu(sk);
274 	mutex_unlock(&net->mctp.bind_lock);
275 
276 	/* remove tag allocations */
277 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
278 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
279 		hlist_del_rcu(&key->sklist);
280 		hlist_del_rcu(&key->hlist);
281 
282 		spin_lock(&key->reasm_lock);
283 		if (key->reasm_head)
284 			kfree_skb(key->reasm_head);
285 		key->reasm_head = NULL;
286 		key->reasm_dead = true;
287 		spin_unlock(&key->reasm_lock);
288 
289 		kfree_rcu(key, rcu);
290 	}
291 	sock_set_flag(sk, SOCK_DEAD);
292 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
293 
294 	synchronize_rcu();
295 }
296 
mctp_sk_destruct(struct sock * sk)297 static void mctp_sk_destruct(struct sock *sk)
298 {
299 	skb_queue_purge(&sk->sk_receive_queue);
300 }
301 
302 static struct proto mctp_proto = {
303 	.name		= "MCTP",
304 	.owner		= THIS_MODULE,
305 	.obj_size	= sizeof(struct mctp_sock),
306 	.init		= mctp_sk_init,
307 	.close		= mctp_sk_close,
308 	.hash		= mctp_sk_hash,
309 	.unhash		= mctp_sk_unhash,
310 };
311 
mctp_pf_create(struct net * net,struct socket * sock,int protocol,int kern)312 static int mctp_pf_create(struct net *net, struct socket *sock,
313 			  int protocol, int kern)
314 {
315 	const struct proto_ops *ops;
316 	struct proto *proto;
317 	struct sock *sk;
318 	int rc;
319 
320 	if (protocol)
321 		return -EPROTONOSUPPORT;
322 
323 	/* only datagram sockets are supported */
324 	if (sock->type != SOCK_DGRAM)
325 		return -ESOCKTNOSUPPORT;
326 
327 	proto = &mctp_proto;
328 	ops = &mctp_dgram_ops;
329 
330 	sock->state = SS_UNCONNECTED;
331 	sock->ops = ops;
332 
333 	sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
334 	if (!sk)
335 		return -ENOMEM;
336 
337 	sock_init_data(sock, sk);
338 	sk->sk_destruct = mctp_sk_destruct;
339 
340 	rc = 0;
341 	if (sk->sk_prot->init)
342 		rc = sk->sk_prot->init(sk);
343 
344 	if (rc)
345 		goto err_sk_put;
346 
347 	return 0;
348 
349 err_sk_put:
350 	sock_orphan(sk);
351 	sock_put(sk);
352 	return rc;
353 }
354 
355 static struct net_proto_family mctp_pf = {
356 	.family = PF_MCTP,
357 	.create = mctp_pf_create,
358 	.owner = THIS_MODULE,
359 };
360 
mctp_init(void)361 static __init int mctp_init(void)
362 {
363 	int rc;
364 
365 	/* ensure our uapi tag definitions match the header format */
366 	BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
367 	BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);
368 
369 	pr_info("mctp: management component transport protocol core\n");
370 
371 	rc = sock_register(&mctp_pf);
372 	if (rc)
373 		return rc;
374 
375 	rc = proto_register(&mctp_proto, 0);
376 	if (rc)
377 		goto err_unreg_sock;
378 
379 	rc = mctp_routes_init();
380 	if (rc)
381 		goto err_unreg_proto;
382 
383 	rc = mctp_neigh_init();
384 	if (rc)
385 		goto err_unreg_routes;
386 
387 	mctp_device_init();
388 
389 	return 0;
390 
391 err_unreg_routes:
392 	mctp_routes_exit();
393 err_unreg_proto:
394 	proto_unregister(&mctp_proto);
395 err_unreg_sock:
396 	sock_unregister(PF_MCTP);
397 
398 	return rc;
399 }
400 
mctp_exit(void)401 static __exit void mctp_exit(void)
402 {
403 	mctp_device_exit();
404 	mctp_neigh_exit();
405 	mctp_routes_exit();
406 	proto_unregister(&mctp_proto);
407 	sock_unregister(PF_MCTP);
408 }
409 
410 module_init(mctp_init);
411 module_exit(mctp_exit);
412 
413 MODULE_DESCRIPTION("MCTP core");
414 MODULE_LICENSE("GPL v2");
415 MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");
416 
417 MODULE_ALIAS_NETPROTO(PF_MCTP);
418