• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2020 HiSilicon (Shanghai) Technologies CO., LIMITED.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  * Description: implementation for socket filter
15  * Author: none
16  * Create: 2020
17  */
18 
19 #include "lwip/opt.h"
20 
21 #if LWIP_SOCK_FILTER
22 
23 #include <string.h>
24 #include "lwip/filter.h"
25 #include "lwip/api.h"
26 
/*
 * Read a big-endian 16/32-bit value through a byte pointer that may be
 * unaligned. Each byte is widened to u32_t before shifting: a byte >= 0x80
 * promoted to int and shifted left by 24 would overflow a signed int, which
 * is undefined behaviour. The argument is parenthesized so that expressions
 * such as `ptr + 2` are accepted.
 */
#define get_unaligned_be16(p)   ((u16_t)((((const u8_t *)(p))[0] << 8) | ((const u8_t *)(p))[1]))
#define get_unaligned_be32(p)   (((u32_t)((const u8_t *)(p))[0] << 24) | \
                                 ((u32_t)((const u8_t *)(p))[1] << 16) | \
                                 ((u32_t)((const u8_t *)(p))[2] << 8) | \
                                 ((u32_t)((const u8_t *)(p))[3]))
29 
/*
 * pbuf_header_pointer - get a pointer to @len bytes at @offset of a packet
 * @pbuf:   packet (possibly a chain of pbufs) to read from
 * @offset: byte offset from the start of the packet
 * @len:    number of bytes requested
 * @buffer: caller-provided scratch area of at least @len bytes
 *
 * If the whole range lies inside the first pbuf, a pointer into its payload
 * is returned without copying. Otherwise the bytes are gathered from the
 * chain into @buffer and @buffer is returned.
 *
 * Returns NULL if @pbuf is NULL or the range does not lie entirely inside the
 * packet. A short copy is a failure: returning @buffer in that case would
 * hand the caller partially uninitialized data.
 */
static void *pbuf_header_pointer(struct pbuf *pbuf, u16_t offset, u16_t len, void *buffer)
{
  if (pbuf == NULL) {
    return NULL;
  }

  /* fast path: the requested bytes are contiguous in the first pbuf */
  if ((offset <= pbuf->len) && ((u16_t)(pbuf->len - offset) >= len)) {
    return (u8_t *)pbuf->payload + offset;
  }

  /* slow path: gather from the chain; fewer than len bytes copied means
   * the range runs past the end of the packet */
  if ((buffer != NULL) && (pbuf_copy_partial(pbuf, buffer, len, offset) == len)) {
    return buffer;
  }

  return NULL;
}
47 
/*
 * load_pointer - locate @size bytes at packet offset @k for a load instruction
 *
 * Delegates to pbuf_header_pointer(); @buffer is scratch space used when the
 * bytes are not contiguous. Returns whatever pbuf_header_pointer() returns.
 */
static inline void *load_pointer(struct pbuf *pbuf, u16_t k, u16_t size, void *buffer)
{
  void *where = pbuf_header_pointer(pbuf, k, size, buffer);

  return where;
}
52 
53 /*
54  *  sock_run_filter - run a filter on a socket
55  *  @pbuf: buffer to run the filter on
56  *  @filter: filter to apply
57  *  @flen: length of filter
58  *
59  * Decode and apply filter instructions to the pbuf->payload.
60  * Return length to keep, 0 for none. pbuf is the data we are
61  * filtering, filter is the array of filter instructions, and
62  * len is the number of filter blocks in the array.
63  */
sock_run_filter(struct pbuf * pbuf,struct sock_filter * filter,u16_t len)64 u32_t sock_run_filter(struct pbuf *pbuf, struct sock_filter *filter, u16_t len)
65 {
66   void *ptr = NULL;
67   u32_t X = 0;
68   u32_t A = 0;
69   u32_t mem[LSF_MEMWORDS] = {0};
70   u32_t tmp;
71   u32_t pc;
72   u32_t k;
73   struct sock_filter *entry = NULL;
74   if (pbuf == NULL) {
75     return 0;
76   }
77   for (pc = 0; pc < len; pc++) {
78     entry = &filter[pc];
79 
80     switch (entry->code) {
81       case LSF_JMP | LSF_JA:
82         pc += entry->k;
83         continue;
84       case LSF_JMP | LSF_JGT | LSF_K:
85         pc += (u32_t)((A > entry->k) ? entry->jt : entry->jf);
86         continue;
87       case LSF_JMP | LSF_JGE | LSF_K:
88         pc += (u32_t)((A >= entry->k) ? entry->jt : entry->jf);
89         continue;
90       case LSF_JMP | LSF_JEQ | LSF_K:
91         pc += (u32_t)((A == entry->k) ? entry->jt : entry->jf);
92         continue;
93       case LSF_JMP | LSF_JSET | LSF_K:
94         pc += (u32_t)((A & entry->k) ? entry->jt : entry->jf);
95         continue;
96       case LSF_JMP | LSF_JGT | LSF_X:
97         pc += (u32_t)((A > X) ? entry->jt : entry->jf);
98         continue;
99       case LSF_JMP | LSF_JGE | LSF_X:
100         pc += (u32_t)((A >= X) ? entry->jt : entry->jf);
101         continue;
102       case LSF_JMP | LSF_JEQ | LSF_X:
103         pc += (u32_t)((A == X) ? entry->jt : entry->jf);
104         continue;
105       case LSF_JMP | LSF_JSET | LSF_X:
106         pc += (u32_t)((A & X) ? entry->jt : entry->jf);
107         continue;
108       case LSF_ALU | LSF_ADD | LSF_X:
109         A += X;
110         continue;
111       case LSF_ALU | LSF_ADD | LSF_K:
112         A += entry->k;
113         continue;
114       case LSF_ALU | LSF_SUB | LSF_X:
115         A -= X;
116         continue;
117       case LSF_ALU | LSF_SUB | LSF_K:
118         A -= entry->k;
119         continue;
120       case LSF_ALU | LSF_MUL | LSF_X:
121         A *= X;
122         continue;
123       case LSF_ALU | LSF_MUL | LSF_K:
124         A *= entry->k;
125         continue;
126       case LSF_ALU | LSF_DIV | LSF_X:
127         if (X == 0) {
128           return 0;
129         }
130         A /= X;
131         continue;
132       case LSF_ALU | LSF_DIV | LSF_K:
133         A /= entry->k;
134         continue;
135       case LSF_ALU | LSF_AND | LSF_X:
136         A &= X;
137         continue;
138       case LSF_ALU | LSF_AND | LSF_K:
139         A &= entry->k;
140         continue;
141       case LSF_ALU | LSF_OR | LSF_X:
142         A |= X;
143         continue;
144       case LSF_ALU | LSF_OR | LSF_K:
145         A |= entry->k;
146         continue;
147       case LSF_ALU | LSF_LSH | LSF_X:
148         A <<= X;
149         continue;
150       case LSF_ALU | LSF_LSH | LSF_K:
151         A <<= entry->k;
152         continue;
153       case LSF_ALU | LSF_RSH | LSF_X:
154         A >>= X;
155         continue;
156       case LSF_ALU | LSF_RSH | LSF_K:
157         A >>= entry->k;
158         continue;
159       case LSF_ALU | LSF_NEG:
160         A = (u32_t)(-(int32_t)A);
161         continue;
162       case LSF_LD | LSF_W | LSF_ABS:
163         k = entry->k;
164 load_w:
165         ptr = load_pointer(pbuf, (u16_t)k, 4, &tmp); // read 4 bytes from pbuf to tmp by offset k
166         if (ptr != NULL) {
167           A = (u32_t)get_unaligned_be32(ptr);
168           continue;
169         }
170         break;
171       case LSF_LD | LSF_H | LSF_ABS:
172         k = entry->k;
173 load_h:
174         ptr = load_pointer(pbuf, (u16_t)k, 2, &tmp); // read 2 bytes from pbuf to tmp by offset k
175         if (ptr != NULL) {
176           A = (u32_t)get_unaligned_be16(ptr);
177           continue;
178         }
179         break;
180       case LSF_LD | LSF_B | LSF_ABS:
181         k = entry->k;
182 load_b:
183         ptr = load_pointer(pbuf, (u16_t)k, 1, &tmp); // read 1 bytes from pbuf to tmp by offset k
184         if (ptr != NULL) {
185           A = *(u8_t *)ptr;
186           continue;
187         }
188         break;
189       case LSF_LD | LSF_W | LSF_LEN:
190         A = pbuf->tot_len;
191         continue;
192       case LSF_LDX | LSF_W | LSF_LEN:
193         X = pbuf->tot_len;
194         continue;
195       case LSF_LD | LSF_W | LSF_IND:
196         k = X + entry->k;
197         goto load_w;
198       case LSF_LD | LSF_H | LSF_IND:
199         k = X + entry->k;
200         goto load_h;
201       case LSF_LD | LSF_B | LSF_IND:
202         k = X + entry->k;
203         goto load_b;
204       case LSF_LD | LSF_IMM:
205         A = entry->k;
206         continue;
207       case LSF_LDX | LSF_IMM:
208         X = entry->k;
209         continue;
210       case LSF_LD | LSF_MEM:
211         A = mem[entry->k];
212         continue;
213       case LSF_LDX | LSF_MEM:
214         X = mem[entry->k];
215         continue;
216       case LSF_MISC | LSF_TAX:
217         X = A;
218         continue;
219       case LSF_MISC | LSF_TXA:
220         A = X;
221         continue;
222       case LSF_ST:
223         mem[entry->k] = A;
224         continue;
225       case LSF_STX:
226         mem[entry->k] = X;
227         continue;
228       case LSF_RET | LSF_K:
229         return entry->k;
230       case LSF_RET | LSF_A:
231         return A;
232       default:
233         return 0;
234     }
235   }
236 
237   return 0;
238 }
239 
240 /*
241  *  sock_filter - run a packet through a socket filter
242  *
243  * Run the filter code and then cut pbuf->payload to correct size returned by
244  * sk_run_filter. If pkt_len is 0 we toss packet. If pbuf->tot_len is smaller
245  * than pkt_len we keep whole pbuf->payload. This is the socket level
246  * wrapper to sk_run_filter. It returns 0 if the packet should
247  * be accepted or EPERM if the packet should be tossed.
248  *
249  */
sock_filter(struct netconn * conn,struct pbuf * pbuf)250 s32_t sock_filter(struct netconn *conn, struct pbuf *pbuf)
251 {
252   s32_t err = ERR_OK;
253   u32_t pkt_len;
254 
255   if (conn->sk_filter.filter != NULL) {
256     pkt_len = sock_run_filter(pbuf, conn->sk_filter.filter, conn->sk_filter.len);
257     if (pkt_len == 0) {
258       err = EPERM;
259     }
260   }
261 
262   return err;
263 }
264 
265 /*
266  *  sock_check_filter - verify socket filter code
267  *  @filter: filter to be verified
268  *  @flen: filter length
269  *
270  *  make sure the user filter code was legal, checking include
271  *  1) the filter must contain no illegal instructions,
272  *  2) no references or jumps that are out of range,
273  *  3) and must end with a RET instruction.
274  *
275  *  All jumps are forward because they are not signed.
276  *
277  *  return 0 if the filter is legal, EINVAL if not.
278  */
sock_check_filter(struct sock_filter * filter,s32_t len)279 s32_t sock_check_filter(struct sock_filter *filter, s32_t len)
280 {
281   struct sock_filter *entry = NULL;
282   s32_t pc;
283 
284   if (len < 1 || len > LSF_MAXINSNS) {
285     return EINVAL;
286   }
287 
288   for (pc = 0; pc < len; pc++) {
289     entry = &filter[pc];
290 
291     switch (entry->code) {
292       case LSF_LD | LSF_W | LSF_ABS:
293       case LSF_LD | LSF_H | LSF_ABS:
294       case LSF_LD | LSF_B | LSF_ABS:
295       case LSF_LD | LSF_W | LSF_LEN:
296       case LSF_LD | LSF_W | LSF_IND:
297       case LSF_LD | LSF_H | LSF_IND:
298       case LSF_LD | LSF_B | LSF_IND:
299       case LSF_LD | LSF_IMM:
300       case LSF_LDX | LSF_W | LSF_LEN:
301       case LSF_LDX | LSF_IMM:
302       case LSF_ALU | LSF_ADD | LSF_K:
303       case LSF_ALU | LSF_ADD | LSF_X:
304       case LSF_ALU | LSF_SUB | LSF_K:
305       case LSF_ALU | LSF_SUB | LSF_X:
306       case LSF_ALU | LSF_MUL | LSF_K:
307       case LSF_ALU | LSF_MUL | LSF_X:
308       case LSF_ALU | LSF_DIV | LSF_X:
309       case LSF_ALU | LSF_AND | LSF_K:
310       case LSF_ALU | LSF_AND | LSF_X:
311       case LSF_ALU | LSF_OR | LSF_K:
312       case LSF_ALU | LSF_OR | LSF_X:
313       case LSF_ALU | LSF_LSH | LSF_K:
314       case LSF_ALU | LSF_LSH | LSF_X:
315       case LSF_ALU | LSF_RSH | LSF_K:
316       case LSF_ALU | LSF_RSH | LSF_X:
317       case LSF_ALU | LSF_NEG:
318       case LSF_RET | LSF_K:
319       case LSF_RET | LSF_A:
320       case LSF_MISC | LSF_TAX:
321       case LSF_MISC | LSF_TXA:
322         break;
323 
324       /* special checks needed for following instructions */
325       case LSF_ALU | LSF_DIV | LSF_K:
326         if (entry->k == 0) {
327           return EINVAL;
328         }
329         break;
330 
331       case LSF_LD | LSF_MEM:
332       case LSF_LDX | LSF_MEM:
333       case LSF_ST:
334       case LSF_STX:
335         /* invalid memory addresses */
336         if (entry->k >= LSF_MEMWORDS) {
337           return EINVAL;
338         }
339         break;
340       case LSF_JMP | LSF_JA:
341         if (entry->k >= (unsigned)(len - pc - 1)) {
342           return EINVAL;
343         }
344         break;
345 
346       case LSF_JMP | LSF_JEQ | LSF_K:
347       case LSF_JMP | LSF_JEQ | LSF_X:
348       case LSF_JMP | LSF_JGE | LSF_K:
349       case LSF_JMP | LSF_JGE | LSF_X:
350       case LSF_JMP | LSF_JGT | LSF_K:
351       case LSF_JMP | LSF_JGT | LSF_X:
352       case LSF_JMP | LSF_JSET | LSF_K:
353       case LSF_JMP | LSF_JSET | LSF_X:
354         if (pc + entry->jt + 1 >= len || pc + entry->jf + 1 >= len) {
355           return EINVAL;
356         }
357         break;
358 
359       default:
360         return EINVAL;
361     }
362   }
363 
364   return (LSF_CLASS(filter[len - 1].code) == LSF_RET) ? 0 : EINVAL;
365 }
366 
367 /*
368  *  sock_attach_filter - attach a user filter
369  *  If an error occurs or there is insufficient memory for the filter a posix
370  *  errno code is returned. On success the return is zero.
371  */
sock_attach_filter(struct sock_fprog * fprog,struct netconn * conn)372 s32_t sock_attach_filter(struct sock_fprog *fprog, struct netconn *conn)
373 {
374   u32_t fsize = sizeof(struct sock_filter) * fprog->len;
375   struct sock_filter *fp = NULL;
376   s32_t err;
377 
378   if (fprog->filter == NULL || fprog->len == 0) {
379     return EINVAL;
380   }
381 
382   err = sock_check_filter(fprog->filter, fprog->len);
383   if (err) {
384     return err;
385   }
386 
387   fp = mem_malloc(fsize);
388   if (fp == NULL) {
389     return ENOMEM;
390   }
391 
392   (void)memcpy_s(fp, fsize, fprog->filter, fsize);
393 
394   if (conn->sk_filter.filter != NULL) {
395     mem_free(conn->sk_filter.filter);
396     conn->sk_filter.filter = NULL;
397     conn->sk_filter.len = 0;
398   }
399 
400   conn->sk_filter.len = fprog->len;
401   conn->sk_filter.filter = fp;
402   return 0;
403 }
404 
sock_detach_filter(struct netconn * conn)405 s32_t sock_detach_filter(struct netconn *conn)
406 {
407   s32_t ret = ENOENT;
408 
409   if (conn->sk_filter.filter) {
410     mem_free(conn->sk_filter.filter);
411     conn->sk_filter.filter = NULL;
412     conn->sk_filter.len = 0;
413     ret = 0;
414   }
415   return ret;
416 }
417 
418 #endif  /* LWIP_SOCK_FILTER */
419