• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1package nghttp2
2
3import (
4	"bufio"
5	"bytes"
6	"context"
7	"crypto/tls"
8	"encoding/binary"
9	"errors"
10	"fmt"
11	"io"
12	"net"
13	"net/http"
14	"net/http/httptest"
15	"net/url"
16	"os"
17	"os/exec"
18	"sort"
19	"strconv"
20	"strings"
21	"syscall"
22	"testing"
23	"time"
24
25	"github.com/lucas-clemente/quic-go/http3"
26	"github.com/tatsuhiro-t/go-nghttp2"
27	"golang.org/x/net/http2"
28	"golang.org/x/net/http2/hpack"
29	"golang.org/x/net/websocket"
30)
31
32const (
33	serverBin  = buildDir + "/src/nghttpx"
34	serverPort = 3009
35	testDir    = sourceDir + "/integration-tests"
36	logDir     = buildDir + "/integration-tests"
37)
38
39func pair(name, value string) hpack.HeaderField {
40	return hpack.HeaderField{
41		Name:  name,
42		Value: value,
43	}
44}
45
46type serverTester struct {
47	cmd           *exec.Cmd // test frontend server process, which is test subject
48	url           string    // test frontend server URL
49	t             *testing.T
50	ts            *httptest.Server // backend server
51	frontendHost  string           // frontend server host
52	backendHost   string           // backend server host
53	conn          net.Conn         // connection to frontend server
54	h2PrefaceSent bool             // HTTP/2 preface was sent in conn
55	nextStreamID  uint32           // next stream ID
56	fr            *http2.Framer    // HTTP/2 framer
57	headerBlkBuf  bytes.Buffer     // buffer to store encoded header block
58	enc           *hpack.Encoder   // HTTP/2 HPACK encoder
59	header        http.Header      // received header fields
60	dec           *hpack.Decoder   // HTTP/2 HPACK decoder
61	authority     string           // server's host:port
62	frCh          chan http2.Frame // used for incoming HTTP/2 frame
63	errCh         chan error
64}
65
// options configures newServerTester. The zero value selects all defaults.
type options struct {
	args        []string         // extra nghttpx arguments; a few pseudo options are interpreted by newServerTester itself
	handler     http.HandlerFunc // backend request handler; noopHandler when nil
	connectPort int              // port the client connects to; serverPort when 0
	tls         bool             // use a TLS frontend (nghttpx listens with TLS and the client performs a handshake)
	tlsConfig   *tls.Config      // client-side TLS configuration, consulted only when tls is true
	tcpData     []byte           // raw bytes written to the connection before the TLS handshake; ignored unless tls is true
	quic        bool             // additionally configure a QUIC frontend; forces tls to true
}
88
89// newServerTester creates test context.
90func newServerTester(t *testing.T, opts options) *serverTester {
91	if opts.quic {
92		opts.tls = true
93	}
94
95	if opts.handler == nil {
96		opts.handler = noopHandler
97	}
98	if opts.connectPort == 0 {
99		opts.connectPort = serverPort
100	}
101
102	ts := httptest.NewUnstartedServer(opts.handler)
103
104	var args []string
105	var backendTLS, dns, externalDNS, acceptProxyProtocol, redirectIfNotTLS, affinityCookie, alpnH1 bool
106
107	for _, k := range opts.args {
108		switch k {
109		case "--http2-bridge":
110			backendTLS = true
111		case "--dns":
112			dns = true
113		case "--external-dns":
114			dns = true
115			externalDNS = true
116		case "--accept-proxy-protocol":
117			acceptProxyProtocol = true
118		case "--redirect-if-not-tls":
119			redirectIfNotTLS = true
120		case "--affinity-cookie":
121			affinityCookie = true
122		case "--alpn-h1":
123			alpnH1 = true
124		default:
125			args = append(args, k)
126		}
127	}
128	if backendTLS {
129		nghttp2.ConfigureServer(ts.Config, &nghttp2.Server{})
130		// According to httptest/server.go, we have to set
131		// NextProtos separately for ts.TLS.  NextProtos set
132		// in nghttp2.ConfigureServer is effectively ignored.
133		ts.TLS = new(tls.Config)
134		ts.TLS.NextProtos = append(ts.TLS.NextProtos, "h2")
135		ts.StartTLS()
136		args = append(args, "-k")
137	} else {
138		ts.Start()
139	}
140	scheme := "http"
141	if opts.tls {
142		scheme = "https"
143		args = append(args, testDir+"/server.key", testDir+"/server.crt")
144	}
145
146	backendURL, err := url.Parse(ts.URL)
147	if err != nil {
148		t.Fatalf("Error parsing URL from httptest.Server: %v", err)
149	}
150
151	// URL.Host looks like "127.0.0.1:8080", but we want
152	// "127.0.0.1,8080"
153	b := "-b"
154	if !externalDNS {
155		b += fmt.Sprintf("%v;", strings.Replace(backendURL.Host, ":", ",", -1))
156	} else {
157		sep := strings.LastIndex(backendURL.Host, ":")
158		if sep == -1 {
159			t.Fatalf("backendURL.Host %v does not contain separator ':'", backendURL.Host)
160		}
161		// We use awesome service nip.io.
162		b += fmt.Sprintf("%v.nip.io,%v;", backendURL.Host[:sep], backendURL.Host[sep+1:])
163	}
164
165	if backendTLS {
166		b += ";proto=h2;tls"
167	}
168	if dns {
169		b += ";dns"
170	}
171
172	if redirectIfNotTLS {
173		b += ";redirect-if-not-tls"
174	}
175
176	if affinityCookie {
177		b += ";affinity=cookie;affinity-cookie-name=affinity;affinity-cookie-path=/foo/bar"
178	}
179
180	noTLS := ";no-tls"
181	if opts.tls {
182		noTLS = ""
183	}
184
185	var proxyProto string
186	if acceptProxyProtocol {
187		proxyProto = ";proxyproto"
188	}
189
190	args = append(args, fmt.Sprintf("-f127.0.0.1,%v%v%v", serverPort, noTLS, proxyProto), b,
191		"--errorlog-file="+logDir+"/log.txt", "-LINFO")
192
193	if opts.quic {
194		args = append(args,
195			fmt.Sprintf("-f127.0.0.1,%v;quic", serverPort),
196			"--no-quic-bpf")
197	}
198
199	authority := fmt.Sprintf("127.0.0.1:%v", opts.connectPort)
200
201	st := &serverTester{
202		cmd:          exec.Command(serverBin, args...),
203		t:            t,
204		ts:           ts,
205		url:          fmt.Sprintf("%v://%v", scheme, authority),
206		frontendHost: fmt.Sprintf("127.0.0.1:%v", serverPort),
207		backendHost:  backendURL.Host,
208		nextStreamID: 1,
209		authority:    authority,
210		frCh:         make(chan http2.Frame),
211		errCh:        make(chan error),
212	}
213
214	st.cmd.Stdout = os.Stdout
215	st.cmd.Stderr = os.Stderr
216
217	if err := st.cmd.Start(); err != nil {
218		st.t.Fatalf("Error starting %v: %v", serverBin, err)
219	}
220
221	retry := 0
222	for {
223		time.Sleep(50 * time.Millisecond)
224
225		conn, err := net.Dial("tcp", authority)
226		if err == nil && opts.tls {
227			if len(opts.tcpData) > 0 {
228				if _, err := conn.Write(opts.tcpData); err != nil {
229					st.Close()
230					st.t.Fatal("Error writing TCP data")
231				}
232			}
233
234			var tlsConfig *tls.Config
235			if opts.tlsConfig == nil {
236				tlsConfig = new(tls.Config)
237			} else {
238				tlsConfig = opts.tlsConfig.Clone()
239			}
240			tlsConfig.InsecureSkipVerify = true
241			if alpnH1 {
242				tlsConfig.NextProtos = []string{"http/1.1"}
243			} else {
244				tlsConfig.NextProtos = []string{"h2"}
245			}
246			tlsConn := tls.Client(conn, tlsConfig)
247			err = tlsConn.Handshake()
248			if err == nil {
249				conn = tlsConn
250			}
251		}
252		if err != nil {
253			retry++
254			if retry >= 100 {
255				st.Close()
256				st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid")
257			}
258			continue
259		}
260		st.conn = conn
261		break
262	}
263
264	st.fr = http2.NewFramer(st.conn, st.conn)
265	st.enc = hpack.NewEncoder(&st.headerBlkBuf)
266	st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) {
267		st.header.Add(f.Name, f.Value)
268	})
269
270	return st
271}
272
273func (st *serverTester) Close() {
274	if st.conn != nil {
275		st.conn.Close()
276	}
277	if st.cmd != nil {
278		done := make(chan struct{})
279		go func() {
280			if err := st.cmd.Wait(); err != nil {
281				st.t.Errorf("Error st.cmd.Wait() = %v", err)
282			}
283			close(done)
284		}()
285
286		if err := st.cmd.Process.Signal(syscall.SIGQUIT); err != nil {
287			st.t.Errorf("Error st.cmd.Process.Signal() = %v", err)
288		}
289
290		select {
291		case <-done:
292		case <-time.After(10 * time.Second):
293			if err := st.cmd.Process.Kill(); err != nil {
294				st.t.Errorf("Error st.cmd.Process.Kill() = %v", err)
295			}
296			<-done
297		}
298	}
299	if st.ts != nil {
300		st.ts.Close()
301	}
302}
303
304func (st *serverTester) readFrame() (http2.Frame, error) {
305	go func() {
306		f, err := st.fr.ReadFrame()
307		if err != nil {
308			st.errCh <- err
309			return
310		}
311		st.frCh <- f
312	}()
313
314	select {
315	case f := <-st.frCh:
316		return f, nil
317	case err := <-st.errCh:
318		return nil, err
319	case <-time.After(5 * time.Second):
320		return nil, errors.New("timeout waiting for frame")
321	}
322}
323
324type requestParam struct {
325	name        string              // name for this request to identify the request in log easily
326	streamID    uint32              // stream ID, automatically assigned if 0
327	method      string              // method, defaults to GET
328	scheme      string              // scheme, defaults to http
329	authority   string              // authority, defaults to backend server address
330	path        string              // path, defaults to /
331	header      []hpack.HeaderField // additional request header fields
332	body        []byte              // request body
333	trailer     []hpack.HeaderField // trailer part
334	httpUpgrade bool                // true if upgraded to HTTP/2 through HTTP Upgrade
335	noEndStream bool                // true if END_STREAM should not be sent
336}
337
338// wrapper for request body to set trailer part
339type chunkedBodyReader struct {
340	trailer        []hpack.HeaderField
341	trailerWritten bool
342	body           io.Reader
343	req            *http.Request
344}
345
346func (cbr *chunkedBodyReader) Read(p []byte) (n int, err error) {
347	// document says that we have to set http.Request.Trailer
348	// after request was sent and before body returns EOF.
349	if !cbr.trailerWritten {
350		cbr.trailerWritten = true
351		for _, h := range cbr.trailer {
352			cbr.req.Trailer.Set(h.Name, h.Value)
353		}
354	}
355	return cbr.body.Read(p)
356}
357
358func (st *serverTester) websocket(rp requestParam) *serverResponse {
359	urlstring := st.url + "/echo"
360
361	config, err := websocket.NewConfig(urlstring, st.url)
362	if err != nil {
363		st.t.Fatalf("websocket.NewConfig(%q, %q) returned error: %v", urlstring, st.url, err)
364	}
365
366	config.Header.Add("Test-Case", rp.name)
367	for _, h := range rp.header {
368		config.Header.Add(h.Name, h.Value)
369	}
370
371	ws, err := websocket.NewClient(config, st.conn)
372	if err != nil {
373		st.t.Fatalf("Error creating websocket client: %v", err)
374	}
375
376	if _, err := ws.Write(rp.body); err != nil {
377		st.t.Fatalf("ws.Write() returned error: %v", err)
378	}
379
380	msg := make([]byte, 1024)
381	var n int
382	if n, err = ws.Read(msg); err != nil {
383		st.t.Fatalf("ws.Read() returned error: %v", err)
384	}
385
386	res := &serverResponse{
387		body: msg[:n],
388	}
389
390	return res
391}
392
393func (st *serverTester) http3(rp requestParam) (*serverResponse, error) {
394	rt := &http3.RoundTripper{
395		TLSClientConfig: &tls.Config{
396			InsecureSkipVerify: true,
397		},
398	}
399
400	defer rt.Close()
401
402	c := &http.Client{
403		Transport: rt,
404	}
405
406	method := "GET"
407	if rp.method != "" {
408		method = rp.method
409	}
410
411	var body io.Reader
412
413	if rp.body != nil {
414		body = bytes.NewBuffer(rp.body)
415	}
416
417	reqURL := st.url
418
419	if rp.path != "" {
420		u, err := url.Parse(st.url)
421		if err != nil {
422			st.t.Fatalf("Error parsing URL from st.url %v: %v", st.url, err)
423		}
424		u.Path = ""
425		u.RawQuery = ""
426		reqURL = u.String() + rp.path
427	}
428
429	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
430	defer cancel()
431
432	req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
433	if err != nil {
434		return nil, err
435	}
436
437	for _, h := range rp.header {
438		req.Header.Add(h.Name, h.Value)
439	}
440
441	req.Header.Add("Test-Case", rp.name)
442
443	// TODO http3 package does not support trailer at the time of
444	// this writing.
445
446	resp, err := c.Do(req)
447	if err != nil {
448		return nil, err
449	}
450
451	defer resp.Body.Close()
452
453	respBody, err := io.ReadAll(resp.Body)
454	if err != nil {
455		return nil, err
456	}
457
458	res := &serverResponse{
459		status:    resp.StatusCode,
460		header:    resp.Header,
461		body:      respBody,
462		connClose: resp.Close,
463	}
464
465	return res, nil
466}
467
468func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
469	method := "GET"
470	if rp.method != "" {
471		method = rp.method
472	}
473
474	var body io.Reader
475	var cbr *chunkedBodyReader
476	if rp.body != nil {
477		body = bytes.NewBuffer(rp.body)
478		if len(rp.trailer) != 0 {
479			cbr = &chunkedBodyReader{
480				trailer: rp.trailer,
481				body:    body,
482			}
483			body = cbr
484		}
485	}
486
487	reqURL := st.url
488
489	if rp.path != "" {
490		u, err := url.Parse(st.url)
491		if err != nil {
492			st.t.Fatalf("Error parsing URL from st.url %v: %v", st.url, err)
493		}
494		u.Path = ""
495		u.RawQuery = ""
496		reqURL = u.String() + rp.path
497	}
498
499	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
500	defer cancel()
501
502	req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
503	if err != nil {
504		return nil, err
505	}
506	for _, h := range rp.header {
507		req.Header.Add(h.Name, h.Value)
508	}
509	req.Header.Add("Test-Case", rp.name)
510	if cbr != nil {
511		cbr.req = req
512		// this makes request use chunked encoding
513		req.ContentLength = -1
514		req.Trailer = make(http.Header)
515		for _, h := range cbr.trailer {
516			req.Trailer.Set(h.Name, "")
517		}
518	}
519	if err := req.Write(st.conn); err != nil {
520		return nil, err
521	}
522	resp, err := http.ReadResponse(bufio.NewReader(st.conn), req)
523	if err != nil {
524		return nil, err
525	}
526	respBody, err := io.ReadAll(resp.Body)
527	if err != nil {
528		return nil, err
529	}
530	resp.Body.Close()
531
532	res := &serverResponse{
533		status:    resp.StatusCode,
534		header:    resp.Header,
535		body:      respBody,
536		connClose: resp.Close,
537	}
538
539	return res, nil
540}
541
542func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
543	st.headerBlkBuf.Reset()
544	st.header = make(http.Header)
545
546	var id uint32
547	if rp.streamID != 0 {
548		id = rp.streamID
549		if id >= st.nextStreamID && id%2 == 1 {
550			st.nextStreamID = id + 2
551		}
552	} else {
553		id = st.nextStreamID
554		st.nextStreamID += 2
555	}
556
557	if !st.h2PrefaceSent {
558		st.h2PrefaceSent = true
559		fmt.Fprint(st.conn, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
560		if err := st.fr.WriteSettings(); err != nil {
561			return nil, err
562		}
563	}
564
565	res := &serverResponse{
566		streamID: id,
567	}
568
569	streams := make(map[uint32]*serverResponse)
570	streams[id] = res
571
572	if !rp.httpUpgrade {
573		method := "GET"
574		if rp.method != "" {
575			method = rp.method
576		}
577		_ = st.enc.WriteField(pair(":method", method))
578
579		scheme := "http"
580		if rp.scheme != "" {
581			scheme = rp.scheme
582		}
583		_ = st.enc.WriteField(pair(":scheme", scheme))
584
585		authority := st.authority
586		if rp.authority != "" {
587			authority = rp.authority
588		}
589		_ = st.enc.WriteField(pair(":authority", authority))
590
591		path := "/"
592		if rp.path != "" {
593			path = rp.path
594		}
595		_ = st.enc.WriteField(pair(":path", path))
596
597		_ = st.enc.WriteField(pair("test-case", rp.name))
598
599		for _, h := range rp.header {
600			_ = st.enc.WriteField(h)
601		}
602
603		err := st.fr.WriteHeaders(http2.HeadersFrameParam{
604			StreamID:      id,
605			EndStream:     len(rp.body) == 0 && len(rp.trailer) == 0 && !rp.noEndStream,
606			EndHeaders:    true,
607			BlockFragment: st.headerBlkBuf.Bytes(),
608		})
609		if err != nil {
610			return nil, err
611		}
612
613		if len(rp.body) != 0 {
614			// TODO we assume rp.body fits in 1 frame
615			if err := st.fr.WriteData(id, len(rp.trailer) == 0 && !rp.noEndStream, rp.body); err != nil {
616				return nil, err
617			}
618		}
619
620		if len(rp.trailer) != 0 {
621			st.headerBlkBuf.Reset()
622			for _, h := range rp.trailer {
623				_ = st.enc.WriteField(h)
624			}
625			err := st.fr.WriteHeaders(http2.HeadersFrameParam{
626				StreamID:      id,
627				EndStream:     true,
628				EndHeaders:    true,
629				BlockFragment: st.headerBlkBuf.Bytes(),
630			})
631			if err != nil {
632				return nil, err
633			}
634		}
635	}
636loop:
637	for {
638		fr, err := st.readFrame()
639		if err != nil {
640			return res, err
641		}
642		switch f := fr.(type) {
643		case *http2.HeadersFrame:
644			_, err := st.dec.Write(f.HeaderBlockFragment())
645			if err != nil {
646				return res, err
647			}
648			sr, ok := streams[f.FrameHeader.StreamID]
649			if !ok {
650				st.header = make(http.Header)
651				break
652			}
653			sr.header = cloneHeader(st.header)
654			var status int
655			status, err = strconv.Atoi(sr.header.Get(":status"))
656			if err != nil {
657				return res, fmt.Errorf("Error parsing status code: %w", err)
658			}
659			sr.status = status
660			if f.StreamEnded() {
661				if streamEnded(res, streams, sr) {
662					break loop
663				}
664			}
665		case *http2.PushPromiseFrame:
666			_, err := st.dec.Write(f.HeaderBlockFragment())
667			if err != nil {
668				return res, err
669			}
670			sr := &serverResponse{
671				streamID:  f.PromiseID,
672				reqHeader: cloneHeader(st.header),
673			}
674			streams[sr.streamID] = sr
675		case *http2.DataFrame:
676			sr, ok := streams[f.FrameHeader.StreamID]
677			if !ok {
678				break
679			}
680			sr.body = append(sr.body, f.Data()...)
681			if f.StreamEnded() {
682				if streamEnded(res, streams, sr) {
683					break loop
684				}
685			}
686		case *http2.RSTStreamFrame:
687			sr, ok := streams[f.FrameHeader.StreamID]
688			if !ok {
689				break
690			}
691			sr.errCode = f.ErrCode
692			if streamEnded(res, streams, sr) {
693				break loop
694			}
695		case *http2.GoAwayFrame:
696			if f.ErrCode == http2.ErrCodeNo {
697				break
698			}
699			res.errCode = f.ErrCode
700			res.connErr = true
701			break loop
702		case *http2.SettingsFrame:
703			if f.IsAck() {
704				break
705			}
706			if err := st.fr.WriteSettingsAck(); err != nil {
707				return res, err
708			}
709		}
710	}
711	sort.Sort(ByStreamID(res.pushResponse))
712	return res, nil
713}
714
715func streamEnded(mainSr *serverResponse, streams map[uint32]*serverResponse, sr *serverResponse) bool {
716	delete(streams, sr.streamID)
717	if mainSr.streamID != sr.streamID {
718		mainSr.pushResponse = append(mainSr.pushResponse, sr)
719	}
720	return len(streams) == 0
721}
722
723type serverResponse struct {
724	status       int               // HTTP status code
725	header       http.Header       // response header fields
726	body         []byte            // response body
727	streamID     uint32            // stream ID in HTTP/2
728	errCode      http2.ErrCode     // error code received in HTTP/2 RST_STREAM or GOAWAY
729	connErr      bool              // true if HTTP/2 connection error
730	connClose    bool              // Connection: close is included in response header in HTTP/1 test
731	reqHeader    http.Header       // http request header, currently only sotres pushed request header
732	pushResponse []*serverResponse // pushed response
733}
734
735type ByStreamID []*serverResponse
736
737func (b ByStreamID) Len() int {
738	return len(b)
739}
740
741func (b ByStreamID) Swap(i, j int) {
742	b[i], b[j] = b[j], b[i]
743}
744
745func (b ByStreamID) Less(i, j int) bool {
746	return b[i].streamID < b[j].streamID
747}
748
// cloneHeader returns a deep copy of h: a new map with a freshly allocated
// value slice for each key. A nil h yields an empty, non-nil header.
func cloneHeader(h http.Header) http.Header {
	out := make(http.Header, len(h))
	for key, values := range h {
		out[key] = append(make([]string, 0, len(values)), values...)
	}
	return out
}
758
// noopHandler drains the request body and otherwise leaves the default 200
// response untouched. If reading the body fails, it responds with 500 and a
// message describing the error.
func noopHandler(w http.ResponseWriter, r *http.Request) {
	_, err := io.ReadAll(r.Body)
	if err == nil {
		return
	}
	msg := fmt.Sprintf("Error io.ReadAll() = %v", err)
	http.Error(w, msg, http.StatusInternalServerError)
}
764
// APIResponse is a JSON response envelope with status, code and data members.
// Each member is omitted from the encoding when it has its zero value.
// NOTE(review): presumably models the body returned by nghttpx's API
// endpoint; confirm against the tests that decode it.
type APIResponse struct {
	Status string                 `json:"status,omitempty"` // textual status
	Code   int                    `json:"code,omitempty"`   // numeric code
	Data   map[string]interface{} `json:"data,omitempty"`   // free-form payload
}
770
// proxyProtocolV2 describes a PROXY protocol version 2 header to be written
// by writeProxyProtocolV2.
type proxyProtocolV2 struct {
	// command is placed in the low nibble of the version/command byte.
	command proxyProtocolV2Command
	// sourceAddress and destinationAddress must both be *net.TCPAddr or both
	// *net.UnixAddr. Any other source address (including nil) is written as
	// the UNSPEC family with no address block.
	sourceAddress      net.Addr
	destinationAddress net.Addr
	// additionalData is appended verbatim after the address block (e.g. TLVs).
	additionalData []byte
}

// proxyProtocolV2Command is the command nibble of a PROXY v2 header.
type proxyProtocolV2Command int

const (
	proxyProtocolV2CommandLocal proxyProtocolV2Command = 0x0
	proxyProtocolV2CommandProxy proxyProtocolV2Command = 0x1
)

// proxyProtocolV2Family is the address family nibble of a PROXY v2 header.
type proxyProtocolV2Family int

const (
	proxyProtocolV2FamilyUnspec proxyProtocolV2Family = 0x0
	proxyProtocolV2FamilyInet   proxyProtocolV2Family = 0x1
	proxyProtocolV2FamilyInet6  proxyProtocolV2Family = 0x2
	proxyProtocolV2FamilyUnix   proxyProtocolV2Family = 0x3
)

// proxyProtocolV2Protocol is the transport protocol nibble of a PROXY v2
// header.
type proxyProtocolV2Protocol int

const (
	proxyProtocolV2ProtocolUnspec proxyProtocolV2Protocol = 0x0
	proxyProtocolV2ProtocolStream proxyProtocolV2Protocol = 0x1
	proxyProtocolV2ProtocolDgram  proxyProtocolV2Protocol = 0x2
)

// writeProxyProtocolV2 encodes hdr as a binary PROXY protocol v2 header and
// writes it to w with a single Write call.
//
// The header consists of:
//   - the 12-byte signature;
//   - the version/command byte (0x20 | command);
//   - the family/protocol byte;
//   - a big-endian 16-bit length;
//   - the address block;
//   - hdr.additionalData.
//
// For *net.TCPAddr, the family is INET for 4-byte IPs and INET6 for 16-byte
// IPs; the block is src IP, dst IP, src port, dst port. For *net.UnixAddr, the
// block is two NUL-padded 108-byte paths. Any other source address produces an
// UNSPEC header with no address block.
//
// It returns an error, and writes nothing, in these cases:
//   - the address types do not match, or an address is a nil pointer;
//   - the IP lengths differ or are neither 4 nor 16;
//   - a port is outside 0-65535;
//   - a Unix path exceeds 108 bytes;
//   - the payload does not fit the 16-bit length field.
func writeProxyProtocolV2(w io.Writer, hdr proxyProtocolV2) error {
	// unixAddrLen is the fixed size of each AF_UNIX address field.
	const unixAddrLen = 108

	// appendUint16 appends v in network (big-endian) byte order.
	appendUint16 := func(b []byte, v uint16) []byte {
		var tmp [2]byte
		binary.BigEndian.PutUint16(tmp[:], v)
		return append(b, tmp[:]...)
	}

	var fam byte
	var addrs []byte

	switch srcAddr := hdr.sourceAddress.(type) {
	case *net.TCPAddr:
		dstAddr, ok := hdr.destinationAddress.(*net.TCPAddr)
		if !ok {
			return fmt.Errorf("destination address must be *net.TCPAddr, got %T", hdr.destinationAddress)
		}
		if srcAddr == nil || dstAddr == nil {
			return errors.New("nil *net.TCPAddr in PROXY protocol header")
		}
		if len(srcAddr.IP) != len(dstAddr.IP) {
			return errors.New("len(srcAddr.IP) != len(dstAddr.IP)")
		}
		switch len(srcAddr.IP) {
		case net.IPv4len:
			fam = byte(proxyProtocolV2FamilyInet << 4)
		case net.IPv6len:
			fam = byte(proxyProtocolV2FamilyInet6 << 4)
		default:
			return fmt.Errorf("unsupported IP address length %d", len(srcAddr.IP))
		}
		fam |= byte(proxyProtocolV2ProtocolStream)
		for _, port := range []int{srcAddr.Port, dstAddr.Port} {
			if port < 0 || port > 0xFFFF {
				return fmt.Errorf("port %d out of range", port)
			}
		}
		addrs = make([]byte, 0, 2*len(srcAddr.IP)+4)
		addrs = append(addrs, srcAddr.IP...)
		addrs = append(addrs, dstAddr.IP...)
		addrs = appendUint16(addrs, uint16(srcAddr.Port))
		addrs = appendUint16(addrs, uint16(dstAddr.Port))
	case *net.UnixAddr:
		dstAddr, ok := hdr.destinationAddress.(*net.UnixAddr)
		if !ok {
			return fmt.Errorf("destination address must be *net.UnixAddr, got %T", hdr.destinationAddress)
		}
		if srcAddr == nil || dstAddr == nil {
			return errors.New("nil *net.UnixAddr in PROXY protocol header")
		}
		if len(srcAddr.Name) > unixAddrLen {
			return errors.New("too long Unix source address")
		}
		if len(dstAddr.Name) > unixAddrLen {
			return errors.New("too long Unix destination address")
		}
		fam = byte(proxyProtocolV2FamilyUnix << 4)
		switch srcAddr.Net {
		case "unix":
			fam |= byte(proxyProtocolV2ProtocolStream)
		case "unixdgram":
			fam |= byte(proxyProtocolV2ProtocolDgram)
		default:
			fam |= byte(proxyProtocolV2ProtocolUnspec)
		}
		// Each path occupies a fixed, NUL-padded field.
		addrs = make([]byte, 2*unixAddrLen)
		copy(addrs, srcAddr.Name)
		copy(addrs[unixAddrLen:], dstAddr.Name)
	default:
		fam = byte(proxyProtocolV2FamilyUnspec<<4) | byte(proxyProtocolV2ProtocolUnspec)
	}

	// The length field is 16 bits wide; reject payloads that would
	// otherwise be silently truncated into a malformed header.
	length := len(addrs) + len(hdr.additionalData)
	if length > 0xFFFF {
		return fmt.Errorf("PROXY protocol v2 payload too large: %d bytes", length)
	}

	buf := make([]byte, 0, 16+length)
	buf = append(buf, 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A)
	buf = append(buf, byte(0x20|hdr.command), fam)
	buf = appendUint16(buf, uint16(length))
	buf = append(buf, addrs...)
	buf = append(buf, hdr.additionalData...)

	_, err := w.Write(buf)
	return err
}
896