1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2020 Anton Protopopov
3 //
4 // Based on tcpconnect(8) from BCC by Brendan Gregg
#include <argp.h>
#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/resource.h>
#include <bpf/bpf.h>
#include "tcpconnect.h"
#include "tcpconnect.skel.h"
#include "trace_helpers.h"
#include "map_helpers.h"
17
/* All diagnostics go to stderr so they never mix with the event stream on stdout. */
#define warn(...) fprintf(stderr, __VA_ARGS__)

/* Set from the SIGINT handler; polled by the count and trace loops. */
static volatile sig_atomic_t exiting = 0;

const char *argp_program_version = "tcpconnect 0.1";
const char *argp_program_bug_address =
	"https://github.com/iovisor/bcc/tree/master/libbpf-tools";
/*
 * Help text shown by --help. The -C/-M examples use the short option
 * letters; the long forms are --cgroupmap and --mntnsmap.
 */
static const char argp_program_doc[] =
	"\ntcpconnect: Count/Trace active tcp connections\n"
	"\n"
	"EXAMPLES:\n"
	" tcpconnect # trace all TCP connect()s\n"
	" tcpconnect -t # include timestamps\n"
	" tcpconnect -p 181 # only trace PID 181\n"
	" tcpconnect -P 80 # only trace port 80\n"
	" tcpconnect -P 80,81 # only trace port 80 and 81\n"
	" tcpconnect -U # include UID\n"
	" tcpconnect -u 1000 # only trace UID 1000\n"
	" tcpconnect -c # count connects per src, dest, port\n"
	" tcpconnect -C mappath # only trace cgroups in the map\n"
	" tcpconnect -M mappath # only trace mount namespaces in the map\n"
	;
40
/*
 * get_int - parse a base-10 integer from @arg and range-check it.
 * @arg: NUL-terminated string; the whole string must be a number.
 * @ret: where to store the value on success (may be NULL; untouched on error).
 * @min: inclusive lower bound.
 * @max: inclusive upper bound.
 *
 * Returns 0 on success. Returns -1 if @arg is empty, has trailing characters,
 * overflows a long (a strtol diagnostic is printed), or is outside [@min, @max].
 */
static int get_int(const char *arg, int *ret, int min, int max)
{
	char *end;
	long val;

	errno = 0;
	val = strtol(arg, &end, 10);
	if (errno) {
		warn("strtol: %s: %s\n", arg, strerror(errno));
		return -1;
	}
	/* Reject empty input and trailing junk such as "12abc". */
	if (end == arg || *end != '\0' || val < min || val > max)
		return -1;
	if (ret)
		*ret = val;
	return 0;
}
58
/*
 * get_ints - parse a comma-separated list of base-10 integers, e.g. "80,81".
 * @arg: NUL-terminated list; every element must be a number within bounds.
 * @size: in: capacity of @ret; out (success only): number of elements stored.
 * @ret: output array with room for at least *@size elements.
 * @min: inclusive lower bound for every element.
 * @max: inclusive upper bound for every element.
 *
 * Returns 0 on success. Returns -1 on an empty element, a separator other
 * than ',', an out-of-range or overflowing value, or when @arg holds more
 * elements than @ret has room for. On failure *@size is left unchanged, but
 * @ret may be partially written.
 */
static int get_ints(const char *arg, int *size, int *ret, int min, int max)
{
	const char *argp = arg;
	int max_size = *size;
	int sz = 0;
	char *end;
	long val;

	while (sz < max_size) {
		errno = 0;
		val = strtol(argp, &end, 10);
		if (errno) {
			warn("strtol: %s: %s\n", arg, strerror(errno));
			return -1;
		}
		/* No digits at the *current* element (e.g. "80,,81" or "80,"). */
		if (end == argp || val < min || val > max)
			return -1;
		ret[sz++] = val;
		if (*end == '\0') {
			*size = sz;
			return 0;
		}
		/* Only a comma may separate elements ("80x81" is invalid). */
		if (*end != ',')
			return -1;
		argp = end + 1;
	}

	/* More elements remain than @ret can hold: refuse to truncate silently. */
	return -1;
}
85
/*
 * get_uint - parse a base-10 unsigned integer from @arg and range-check it.
 * @arg: NUL-terminated string; the whole string must be a non-negative number.
 * @ret: where to store the value on success (may be NULL; untouched on error).
 * @min: inclusive lower bound.
 * @max: inclusive upper bound.
 *
 * Returns 0 on success. Returns -1 if @arg is empty, contains '-', has
 * trailing characters, overflows an unsigned long (a strtoul diagnostic is
 * printed), or is outside [@min, @max].
 */
static int get_uint(const char *arg, unsigned int *ret,
		    unsigned int min, unsigned int max)
{
	char *end;
	unsigned long val;

	/*
	 * strtoul() accepts a leading '-' and returns the negated value, which
	 * can wrap back into range. No valid unsigned decimal contains '-'.
	 */
	if (strchr(arg, '-'))
		return -1;

	errno = 0;
	val = strtoul(arg, &end, 10);
	if (errno) {
		warn("strtoul: %s: %s\n", arg, strerror(errno));
		return -1;
	}
	if (end == arg || *end != '\0' || val < min || val > max)
		return -1;
	if (ret)
		*ret = val;
	return 0;
}
104
/*
 * Command-line options understood by parse_arg(). The hidden 'h' entry lets
 * -h print the full help; the final all-zero entry terminates the table.
 */
static const struct argp_option opts[] = {
	{ .name = "verbose", .key = 'v', .arg = NULL, .flags = 0,
	  .doc = "Verbose debug output" },
	{ .name = "timestamp", .key = 't', .arg = NULL, .flags = 0,
	  .doc = "Include timestamp on output" },
	{ .name = "count", .key = 'c', .arg = NULL, .flags = 0,
	  .doc = "Count connects per src ip and dst ip/port" },
	{ .name = "print-uid", .key = 'U', .arg = NULL, .flags = 0,
	  .doc = "Include UID on output" },
	{ .name = "pid", .key = 'p', .arg = "PID", .flags = 0,
	  .doc = "Process PID to trace" },
	{ .name = "uid", .key = 'u', .arg = "UID", .flags = 0,
	  .doc = "Process UID to trace" },
	{ .name = "port", .key = 'P', .arg = "PORTS", .flags = 0,
	  .doc = "Comma-separated list of destination ports to trace" },
	{ .name = "cgroupmap", .key = 'C', .arg = "PATH", .flags = 0,
	  .doc = "trace cgroups in this map" },
	{ .name = "mntnsmap", .key = 'M', .arg = "PATH", .flags = 0,
	  .doc = "trace mount namespaces in this map" },
	{ .name = NULL, .key = 'h', .arg = NULL, .flags = OPTION_HIDDEN,
	  .doc = "Show the full help" },
	{ .name = NULL },	/* terminator */
};
119
120 static struct env {
121 bool verbose;
122 bool count;
123 bool print_timestamp;
124 bool print_uid;
125 pid_t pid;
126 uid_t uid;
127 int nports;
128 int ports[MAX_PORTS];
129 } env = {
130 .uid = (uid_t) -1,
131 };
132
parse_arg(int key,char * arg,struct argp_state * state)133 static error_t parse_arg(int key, char *arg, struct argp_state *state)
134 {
135 int err;
136 int nports;
137
138 switch (key) {
139 case 'h':
140 argp_state_help(state, stderr, ARGP_HELP_STD_HELP);
141 break;
142 case 'v':
143 env.verbose = true;
144 break;
145 case 'c':
146 env.count = true;
147 break;
148 case 't':
149 env.print_timestamp = true;
150 break;
151 case 'U':
152 env.print_uid = true;
153 break;
154 case 'p':
155 err = get_int(arg, &env.pid, 1, INT_MAX);
156 if (err) {
157 warn("invalid PID: %s\n", arg);
158 argp_usage(state);
159 }
160 break;
161 case 'u':
162 err = get_uint(arg, &env.uid, 0, (uid_t) -2);
163 if (err) {
164 warn("invalid UID: %s\n", arg);
165 argp_usage(state);
166 }
167 break;
168 case 'P':
169 nports = MAX_PORTS;
170 err = get_ints(arg, &nports, env.ports, 1, 65535);
171 if (err) {
172 warn("invalid PORT_LIST: %s\n", arg);
173 argp_usage(state);
174 }
175 env.nports = nports;
176 break;
177 case 'C':
178 warn("not implemented: --cgroupmap");
179 break;
180 case 'M':
181 warn("not implemented: --mntnsmap");
182 break;
183 default:
184 return ARGP_ERR_UNKNOWN;
185 }
186 return 0;
187 }
188
libbpf_print_fn(enum libbpf_print_level level,const char * format,va_list args)189 static int libbpf_print_fn(enum libbpf_print_level level, const char *format, va_list args)
190 {
191 if (level == LIBBPF_DEBUG && !env.verbose)
192 return 0;
193 return vfprintf(stderr, format, args);
194 }
195
sig_int(int signo)196 static void sig_int(int signo)
197 {
198 exiting = 1;
199 }
200
print_count_ipv4(int map_fd)201 static void print_count_ipv4(int map_fd)
202 {
203 static struct ipv4_flow_key keys[MAX_ENTRIES];
204 __u32 value_size = sizeof(__u64);
205 __u32 key_size = sizeof(keys[0]);
206 static struct ipv4_flow_key zero;
207 static __u64 counts[MAX_ENTRIES];
208 char s[INET_ADDRSTRLEN];
209 char d[INET_ADDRSTRLEN];
210 __u32 i, n = MAX_ENTRIES;
211 struct in_addr src;
212 struct in_addr dst;
213
214 if (dump_hash(map_fd, keys, key_size, counts, value_size, &n, &zero)) {
215 warn("dump_hash: %s", strerror(errno));
216 return;
217 }
218
219 for (i = 0; i < n; i++) {
220 src.s_addr = keys[i].saddr;
221 dst.s_addr = keys[i].daddr;
222
223 printf("%-25s %-25s %-20d %-10llu\n",
224 inet_ntop(AF_INET, &src, s, sizeof(s)),
225 inet_ntop(AF_INET, &dst, d, sizeof(d)),
226 ntohs(keys[i].dport), counts[i]);
227 }
228 }
229
print_count_ipv6(int map_fd)230 static void print_count_ipv6(int map_fd)
231 {
232 static struct ipv6_flow_key keys[MAX_ENTRIES];
233 __u32 value_size = sizeof(__u64);
234 __u32 key_size = sizeof(keys[0]);
235 static struct ipv6_flow_key zero;
236 static __u64 counts[MAX_ENTRIES];
237 char s[INET6_ADDRSTRLEN];
238 char d[INET6_ADDRSTRLEN];
239 __u32 i, n = MAX_ENTRIES;
240 struct in6_addr src;
241 struct in6_addr dst;
242
243 if (dump_hash(map_fd, keys, key_size, counts, value_size, &n, &zero)) {
244 warn("dump_hash: %s", strerror(errno));
245 return;
246 }
247
248 for (i = 0; i < n; i++) {
249 memcpy(src.s6_addr, keys[i].saddr, sizeof(src.s6_addr));
250 memcpy(dst.s6_addr, keys[i].daddr, sizeof(src.s6_addr));
251
252 printf("%-25s %-25s %-20d %-10llu\n",
253 inet_ntop(AF_INET6, &src, s, sizeof(s)),
254 inet_ntop(AF_INET6, &dst, d, sizeof(d)),
255 ntohs(keys[i].dport), counts[i]);
256 }
257 }
258
print_count(int map_fd_ipv4,int map_fd_ipv6)259 static void print_count(int map_fd_ipv4, int map_fd_ipv6)
260 {
261 static const char *header_fmt = "\n%-25s %-25s %-20s %-10s\n";
262
263 while (!exiting)
264 pause();
265
266 printf(header_fmt, "LADDR", "RADDR", "RPORT", "CONNECTS");
267 print_count_ipv4(map_fd_ipv4);
268 print_count_ipv6(map_fd_ipv6);
269 }
270
print_events_header()271 static void print_events_header()
272 {
273 if (env.print_timestamp)
274 printf("%-9s", "TIME(s)");
275 if (env.print_uid)
276 printf("%-6s", "UID");
277 printf("%-6s %-12s %-2s %-16s %-16s %-4s\n",
278 "PID", "COMM", "IP", "SADDR", "DADDR", "DPORT");
279 }
280
/*
 * perf buffer sample callback: decode one struct event from @data and print
 * it as a single line. With -t, the line is prefixed by seconds since the
 * first event; with -U, it is prefixed by the UID. Events whose address
 * family is neither AF_INET nor AF_INET6 are reported and skipped.
 * @ctx, @cpu and @data_sz are unused.
 */
static void handle_event(void *ctx, int cpu, void *data, __u32 data_sz)
{
	const struct event *event = data;
	char src[INET6_ADDRSTRLEN];
	char dst[INET6_ADDRSTRLEN];
	union {
		struct in_addr x4;
		struct in6_addr x6;
	} s, d;
	/* ts_us of the first timestamped event; later times are relative to it. */
	static __u64 start_ts;

	if (event->af == AF_INET) {
		s.x4.s_addr = event->saddr_v4;
		d.x4.s_addr = event->daddr_v4;
	} else if (event->af == AF_INET6) {
		memcpy(&s.x6.s6_addr, event->saddr_v6, sizeof(s.x6.s6_addr));
		memcpy(&d.x6.s6_addr, event->daddr_v6, sizeof(d.x6.s6_addr));
	} else {
		warn("broken event: event->af=%u\n", (unsigned int)event->af);
		return;
	}

	if (env.print_timestamp) {
		if (start_ts == 0)
			start_ts = event->ts_us;
		printf("%-9.3f", (event->ts_us - start_ts) / 1000000.0);
	}

	/* UIDs are unsigned; %d would print values above INT_MAX as negative. */
	if (env.print_uid)
		printf("%-6u", (unsigned int)event->uid);

	printf("%-6d %-12.12s %-2d %-16s %-16s %-4d\n",
	       event->pid, event->task,
	       event->af == AF_INET ? 4 : 6,
	       inet_ntop(event->af, &s, src, sizeof(src)),
	       inet_ntop(event->af, &d, dst, sizeof(dst)),
	       ntohs(event->dport));
}
319
/*
 * perf buffer lost-sample callback: report on stderr that @lost_cnt events
 * from @cpu were dropped. @ctx is unused.
 */
static void handle_lost_events(void *ctx, int cpu, __u64 lost_cnt)
{
	(void)ctx;

	warn("Lost %llu events on CPU #%d!\n", lost_cnt, cpu);
}
324
print_events(int perf_map_fd)325 static void print_events(int perf_map_fd)
326 {
327 struct perf_buffer *pb;
328 int err;
329
330 pb = perf_buffer__new(perf_map_fd, 128,
331 handle_event, handle_lost_events, NULL, NULL);
332 if (!pb) {
333 err = -errno;
334 warn("failed to open perf buffer: %d\n", err);
335 goto cleanup;
336 }
337
338 print_events_header();
339 while (!exiting) {
340 err = perf_buffer__poll(pb, 100);
341 if (err < 0 && err != -EINTR) {
342 warn("error polling perf buffer: %s\n", strerror(-err));
343 goto cleanup;
344 }
345 /* reset err to return 0 if exiting */
346 err = 0;
347 }
348
349 cleanup:
350 perf_buffer__free(pb);
351 }
352
main(int argc,char ** argv)353 int main(int argc, char **argv)
354 {
355 static const struct argp argp = {
356 .options = opts,
357 .parser = parse_arg,
358 .doc = argp_program_doc,
359 .args_doc = NULL,
360 };
361 struct tcpconnect_bpf *obj;
362 int i, err;
363
364 err = argp_parse(&argp, argc, argv, 0, NULL, NULL);
365 if (err)
366 return err;
367
368 libbpf_set_strict_mode(LIBBPF_STRICT_ALL);
369 libbpf_set_print(libbpf_print_fn);
370
371 obj = tcpconnect_bpf__open();
372 if (!obj) {
373 warn("failed to open BPF object\n");
374 return 1;
375 }
376
377 if (env.count)
378 obj->rodata->do_count = true;
379 if (env.pid)
380 obj->rodata->filter_pid = env.pid;
381 if (env.uid != (uid_t) -1)
382 obj->rodata->filter_uid = env.uid;
383 if (env.nports > 0) {
384 obj->rodata->filter_ports_len = env.nports;
385 for (i = 0; i < env.nports; i++) {
386 obj->rodata->filter_ports[i] = htons(env.ports[i]);
387 }
388 }
389
390 err = tcpconnect_bpf__load(obj);
391 if (err) {
392 warn("failed to load BPF object: %d\n", err);
393 goto cleanup;
394 }
395
396 err = tcpconnect_bpf__attach(obj);
397 if (err) {
398 warn("failed to attach BPF programs: %s\n", strerror(-err));
399 goto cleanup;
400 }
401
402 if (signal(SIGINT, sig_int) == SIG_ERR) {
403 warn("can't set signal handler: %s\n", strerror(errno));
404 err = 1;
405 goto cleanup;
406 }
407
408 if (env.count) {
409 print_count(bpf_map__fd(obj->maps.ipv4_count),
410 bpf_map__fd(obj->maps.ipv6_count));
411 } else {
412 print_events(bpf_map__fd(obj->maps.events));
413 }
414
415 cleanup:
416 tcpconnect_bpf__destroy(obj);
417
418 return err != 0;
419 }
420