• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2# @lint-avoid-python-3-compatibility-imports
3#
4# tcprtt    Summarize TCP RTT as a histogram. For Linux, uses BCC, eBPF.
5#
6# USAGE: tcprtt [-h] [-T] [-D] [-m] [-i INTERVAL] [-d DURATION]
7#           [-p LPORT] [-P RPORT] [-a LADDR] [-A RADDR] [-b] [-B] [-e]
8#           [-4 | -6]
9#
10# Copyright (c) 2020 zhenwei pi
11# Licensed under the Apache License, Version 2.0 (the "License")
12#
13# 23-AUG-2020  zhenwei pi  Created this.
14
15from __future__ import print_function
16from bcc import BPF
17from time import sleep, strftime
18from socket import inet_ntop, AF_INET
19import socket, struct
20import argparse
21import ctypes
22
# arguments
# Every filter option below takes a value, so the examples show one.
examples = """examples:
    ./tcprtt            # summarize TCP RTT
    ./tcprtt -i 1 -d 10 # print 1 second summaries, 10 times
    ./tcprtt -m -T      # summarize in millisecond, and timestamps
    ./tcprtt -p 80      # filter for local port
    ./tcprtt -P 80      # filter for remote port
    ./tcprtt -a 10.0.0.1 # filter for local address
    ./tcprtt -A 10.0.0.1 # filter for remote address
    ./tcprtt -b         # show sockets histogram by local address
    ./tcprtt -B         # show sockets histogram by remote address
    ./tcprtt -D         # show debug bpf text
    ./tcprtt -e         # show extension summary(average)
    ./tcprtt -4         # trace only IPv4 family
    ./tcprtt -6         # trace only IPv6 family
"""
parser = argparse.ArgumentParser(
    description="Summarize TCP RTT as a histogram",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=examples)
# type=int rejects a malformed interval at parse time instead of at the first
# sleep, and makes "-i 0" falsy so it falls back to the duration below rather
# than producing a sleep(0) busy loop.
parser.add_argument("-i", "--interval", type=int,
    help="summary interval, seconds")
parser.add_argument("-d", "--duration", type=int, default=99999,
    help="total duration of trace, seconds")
parser.add_argument("-T", "--timestamp", action="store_true",
    help="include timestamp on output")
parser.add_argument("-m", "--milliseconds", action="store_true",
    help="millisecond histogram")
# Port/address filters are kept as strings; they are converted when the
# BPF filter code is generated.
parser.add_argument("-p", "--lport",
    help="filter for local port")
parser.add_argument("-P", "--rport",
    help="filter for remote port")
parser.add_argument("-a", "--laddr",
    help="filter for local address")
parser.add_argument("-A", "--raddr",
    help="filter for remote address")
parser.add_argument("-b", "--byladdr", action="store_true",
    help="show sockets histogram by local address")
parser.add_argument("-B", "--byraddr", action="store_true",
    help="show sockets histogram by remote address")
parser.add_argument("-e", "--extension", action="store_true",
    help="show extension summary(average)")
parser.add_argument("-D", "--debug", action="store_true",
    help="print BPF program before starting (for debugging purposes)")
group = parser.add_mutually_exclusive_group()
group.add_argument("-4", "--ipv4", action="store_true",
    help="trace IPv4 family only")
group.add_argument("-6", "--ipv6", action="store_true",
    help="trace IPv6 family only")
# Hidden flag: dump the generated BPF program and exit without loading it.
parser.add_argument("--ebpf", action="store_true",
    help=argparse.SUPPRESS)
args = parser.parse_args()
# Without an interval (or with -i 0), print one summary for the whole duration.
if not args.interval:
    args.interval = args.duration

# define BPF program
#
# C source for the kprobe on tcp_rcv_established. The upper-case tokens
# (LPORTFILTER, RPORTFILTER, LADDRFILTER, RADDRFILTER, FAMILYFILTER, FACTOR,
# STORE_HIST, STORE_LATENCY) are placeholders that the script replaces with
# generated C (or with nothing) based on the command-line options.
#
# srtt_us is shifted right by 3 because struct tcp_sock keeps the smoothed RTT
# scaled by 8 (per the comment on srtt_us in include/linux/tcp.h).
# hist_srtt is keyed by (addr, log2 slot); "latency" maps addr to a running
# RTT sum and sample count, used for the -e average.
bpf_text = """
#ifndef KBUILD_MODNAME
#define KBUILD_MODNAME "bcc"
#endif
#include <uapi/linux/ptrace.h>
#include <linux/tcp.h>
#include <net/sock.h>
#include <net/inet_sock.h>
#include <bcc/proto.h>

typedef struct sock_key {
    u64 addr;
    u64 slot;
} sock_key_t;

typedef struct sock_latency {
    u64 latency;
    u64 count;
} sock_latency_t;

BPF_HISTOGRAM(hist_srtt, sock_key_t);
BPF_HASH(latency, u64, sock_latency_t);

int trace_tcp_rcv(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb)
{
    struct tcp_sock *ts = tcp_sk(sk);
    u32 srtt = ts->srtt_us >> 3;
    const struct inet_sock *inet = inet_sk(sk);

    /* filters */
    u16 sport = 0;
    u16 dport = 0;
    u32 saddr = 0;
    u32 daddr = 0;
    u16 family = 0;

    /* for histogram */
    sock_key_t key;

    /* for avg latency, if no saddr/daddr specified, use 0(addr) as key */
    u64 addr = 0;

    bpf_probe_read_kernel(&sport, sizeof(sport), (void *)&inet->inet_sport);
    bpf_probe_read_kernel(&dport, sizeof(dport), (void *)&inet->inet_dport);
    bpf_probe_read_kernel(&saddr, sizeof(saddr), (void *)&inet->inet_saddr);
    bpf_probe_read_kernel(&daddr, sizeof(daddr), (void *)&inet->inet_daddr);
    bpf_probe_read_kernel(&family, sizeof(family), (void *)&sk->__sk_common.skc_family);

    LPORTFILTER
    RPORTFILTER
    LADDRFILTER
    RADDRFILTER
    FAMILYFILTER

    FACTOR

    STORE_HIST
    key.slot = bpf_log2l(srtt);
    hist_srtt.atomic_increment(key);

    STORE_LATENCY

    return 0;
}
"""

# filter for local port
if args.lport:
    bpf_text = bpf_text.replace('LPORTFILTER',
        """if (ntohs(sport) != %d)
        return 0;""" % int(args.lport))
else:
    bpf_text = bpf_text.replace('LPORTFILTER', '')

# filter for remote port
if args.rport:
    bpf_text = bpf_text.replace('RPORTFILTER',
        """if (ntohs(dport) != %d)
        return 0;""" % int(args.rport))
else:
    bpf_text = bpf_text.replace('RPORTFILTER', '')

# filter for local address
# inet_aton yields the 4-byte IPv4 address in network order; unpacking it
# with "=I" gives the same u32 the probe reads raw from inet_saddr.
if args.laddr:
    bpf_text = bpf_text.replace('LADDRFILTER',
        """if (saddr != %d)
        return 0;""" % struct.unpack("=I", socket.inet_aton(args.laddr))[0])
else:
    bpf_text = bpf_text.replace('LADDRFILTER', '')

# filter for remote address
if args.raddr:
    bpf_text = bpf_text.replace('RADDRFILTER',
        """if (daddr != %d)
        return 0;""" % struct.unpack("=I", socket.inet_aton(args.raddr))[0])
else:
    bpf_text = bpf_text.replace('RADDRFILTER', '')

# restrict to one address family (-4 and -6 are mutually exclusive)
if args.ipv4:
    bpf_text = bpf_text.replace('FAMILYFILTER',
        'if (family != AF_INET) { return 0; }')
elif args.ipv6:
    bpf_text = bpf_text.replace('FAMILYFILTER',
        'if (family != AF_INET6) { return 0; }')
else:
    bpf_text = bpf_text.replace('FAMILYFILTER', '')

# show msecs or usecs[default]
if args.milliseconds:
    bpf_text = bpf_text.replace('FACTOR', 'srtt /= 1000;')
    label = "msecs"
else:
    bpf_text = bpf_text.replace('FACTOR', '')
    label = "usecs"

print_header = "srtt"
# show byladdr/byraddr histogram
# "addr" doubles as the key of the "latency" map used by -e.
if args.byladdr:
    bpf_text = bpf_text.replace('STORE_HIST', 'key.addr = addr = saddr;')
    print_header = "Local Address"
elif args.byraddr:
    bpf_text = bpf_text.replace('STORE_HIST', 'key.addr = addr = daddr;')
    print_header = "Remote Address"
else:
    bpf_text = bpf_text.replace('STORE_HIST', 'key.addr = addr = 0;')
    print_header = "All Addresses"

# -e: keep a per-address RTT sum and sample count for the average printed next
# to each histogram section. The probe can run on several CPUs at once, so an
# existing entry is updated with atomic adds; plain "+=" would lose updates.
# (Two CPUs can still both miss the lookup on first insert, and one of those
# samples is then dropped.)
if args.extension:
    bpf_text = bpf_text.replace('STORE_LATENCY', """
    sock_latency_t newlat = {0};
    sock_latency_t *lat;
    lat = latency.lookup(&addr);
    if (!lat) {
        newlat.latency += srtt;
        newlat.count += 1;
        latency.update(&addr, &newlat);
    } else {
        __sync_fetch_and_add(&lat->latency, srtt);
        __sync_fetch_and_add(&lat->count, 1);
    }
    """)
else:
    bpf_text = bpf_text.replace('STORE_LATENCY', '')

# debug/dump ebpf enable or not
# -D prints the generated program and continues; --ebpf prints it and stops
# before anything is loaded into the kernel.
if args.debug or args.ebpf:
    print(bpf_text)
    if args.ebpf:
        # exit() is the site-module helper meant for the interactive shell;
        # raising SystemExit directly ends the script the same way (status 0).
        raise SystemExit

# load BPF program
b = BPF(text=bpf_text)
b.attach_kprobe(event="tcp_rcv_established", fn_name="trace_tcp_rcv")

print("Tracing TCP RTT... Hit Ctrl-C to end.")

def print_section(addr):
    """Format one section header for ``print_log2_hist``.

    ``addr`` is the ``key.addr`` value from the hist_srtt map. With -b/-B it
    is the IPv4 address read raw from the socket's inet_saddr/inet_daddr;
    otherwise it is 0 and shown as a placeholder. With -e, the average RTT
    recorded in the "latency" map for the same key is appended.
    """
    addrstr = "*******"
    if addr:
        addrstr = inet_ntop(AF_INET, struct.pack("I", addr))

    avglat = ""
    if args.extension:
        lats = b.get_table("latency")
        try:
            # The map key is a u64; c_uint64 is 8 bytes on every ABI, whereas
            # c_ulong is only 4 bytes on 32-bit Linux.
            lat = lats[ctypes.c_uint64(addr)]
        except KeyError:
            # The probe bumps hist_srtt before it updates "latency", and the
            # two maps are cleared one after the other, so a histogram section
            # may have no latency entry yet. Print it without an average.
            lat = None
        if lat is not None and lat.count:
            avglat = " [AVG %d]" % (lat.latency / lat.count)

    return addrstr + avglat

# output
# Print one histogram per interval. Ctrl-C ends the current interval early;
# that partial interval is printed as the final summary. Both maps are cleared
# after each print, so every summary covers only the latest interval.
interval = int(args.interval)
exiting = 0 if args.interval else 1
dist = b.get_table("hist_srtt")
lathash = b.get_table("latency")
seconds = 0
while True:
    try:
        sleep(interval)
        seconds = seconds + interval
    except KeyboardInterrupt:
        exiting = 1

    print()
    if args.timestamp:
        print("%-8s\n" % strftime("%H:%M:%S"), end="")

    dist.print_log2_hist(label, section_header=print_header,
                         section_print_fn=print_section)
    dist.clear()
    lathash.clear()

    if exiting or seconds >= args.duration:
        # Same effect as the site-module exit(), without depending on it.
        raise SystemExit