1package nghttp2 2 3import ( 4 "bufio" 5 "bytes" 6 "context" 7 "crypto/tls" 8 "encoding/binary" 9 "errors" 10 "fmt" 11 "io" 12 "net" 13 "net/http" 14 "net/http/httptest" 15 "net/url" 16 "os" 17 "os/exec" 18 "sort" 19 "strconv" 20 "strings" 21 "syscall" 22 "testing" 23 "time" 24 25 "github.com/lucas-clemente/quic-go/http3" 26 "github.com/tatsuhiro-t/go-nghttp2" 27 "golang.org/x/net/http2" 28 "golang.org/x/net/http2/hpack" 29 "golang.org/x/net/websocket" 30) 31 32const ( 33 serverBin = buildDir + "/src/nghttpx" 34 serverPort = 3009 35 testDir = sourceDir + "/integration-tests" 36 logDir = buildDir + "/integration-tests" 37) 38 39func pair(name, value string) hpack.HeaderField { 40 return hpack.HeaderField{ 41 Name: name, 42 Value: value, 43 } 44} 45 46type serverTester struct { 47 cmd *exec.Cmd // test frontend server process, which is test subject 48 url string // test frontend server URL 49 t *testing.T 50 ts *httptest.Server // backend server 51 frontendHost string // frontend server host 52 backendHost string // backend server host 53 conn net.Conn // connection to frontend server 54 h2PrefaceSent bool // HTTP/2 preface was sent in conn 55 nextStreamID uint32 // next stream ID 56 fr *http2.Framer // HTTP/2 framer 57 headerBlkBuf bytes.Buffer // buffer to store encoded header block 58 enc *hpack.Encoder // HTTP/2 HPACK encoder 59 header http.Header // received header fields 60 dec *hpack.Decoder // HTTP/2 HPACK decoder 61 authority string // server's host:port 62 frCh chan http2.Frame // used for incoming HTTP/2 frame 63 errCh chan error 64} 65 66type options struct { 67 // args is the additional arguments to nghttpx. 68 args []string 69 // handler is the handler to handle the request. It defaults 70 // to noopHandler. 71 handler http.HandlerFunc 72 // connectPort is the server side port where client connection 73 // is made. It defaults to serverPort. 74 connectPort int 75 // tls, if set to true, sets up TLS frontend connection. 
76 tls bool 77 // tlsConfig is the client side TLS configuration that is used 78 // when tls is true. 79 tlsConfig *tls.Config 80 // tcpData is additional data that are written to connection 81 // before TLS handshake starts. This field is ignored if tls 82 // is false. 83 tcpData []byte 84 // quic, if set to true, sets up QUIC frontend connection. 85 // quic implies tls = true. 86 quic bool 87} 88 89// newServerTester creates test context. 90func newServerTester(t *testing.T, opts options) *serverTester { 91 if opts.quic { 92 opts.tls = true 93 } 94 95 if opts.handler == nil { 96 opts.handler = noopHandler 97 } 98 if opts.connectPort == 0 { 99 opts.connectPort = serverPort 100 } 101 102 ts := httptest.NewUnstartedServer(opts.handler) 103 104 var args []string 105 var backendTLS, dns, externalDNS, acceptProxyProtocol, redirectIfNotTLS, affinityCookie, alpnH1 bool 106 107 for _, k := range opts.args { 108 switch k { 109 case "--http2-bridge": 110 backendTLS = true 111 case "--dns": 112 dns = true 113 case "--external-dns": 114 dns = true 115 externalDNS = true 116 case "--accept-proxy-protocol": 117 acceptProxyProtocol = true 118 case "--redirect-if-not-tls": 119 redirectIfNotTLS = true 120 case "--affinity-cookie": 121 affinityCookie = true 122 case "--alpn-h1": 123 alpnH1 = true 124 default: 125 args = append(args, k) 126 } 127 } 128 if backendTLS { 129 nghttp2.ConfigureServer(ts.Config, &nghttp2.Server{}) 130 // According to httptest/server.go, we have to set 131 // NextProtos separately for ts.TLS. NextProtos set 132 // in nghttp2.ConfigureServer is effectively ignored. 
133 ts.TLS = new(tls.Config) 134 ts.TLS.NextProtos = append(ts.TLS.NextProtos, "h2") 135 ts.StartTLS() 136 args = append(args, "-k") 137 } else { 138 ts.Start() 139 } 140 scheme := "http" 141 if opts.tls { 142 scheme = "https" 143 args = append(args, testDir+"/server.key", testDir+"/server.crt") 144 } 145 146 backendURL, err := url.Parse(ts.URL) 147 if err != nil { 148 t.Fatalf("Error parsing URL from httptest.Server: %v", err) 149 } 150 151 // URL.Host looks like "127.0.0.1:8080", but we want 152 // "127.0.0.1,8080" 153 b := "-b" 154 if !externalDNS { 155 b += fmt.Sprintf("%v;", strings.Replace(backendURL.Host, ":", ",", -1)) 156 } else { 157 sep := strings.LastIndex(backendURL.Host, ":") 158 if sep == -1 { 159 t.Fatalf("backendURL.Host %v does not contain separator ':'", backendURL.Host) 160 } 161 // We use awesome service nip.io. 162 b += fmt.Sprintf("%v.nip.io,%v;", backendURL.Host[:sep], backendURL.Host[sep+1:]) 163 } 164 165 if backendTLS { 166 b += ";proto=h2;tls" 167 } 168 if dns { 169 b += ";dns" 170 } 171 172 if redirectIfNotTLS { 173 b += ";redirect-if-not-tls" 174 } 175 176 if affinityCookie { 177 b += ";affinity=cookie;affinity-cookie-name=affinity;affinity-cookie-path=/foo/bar" 178 } 179 180 noTLS := ";no-tls" 181 if opts.tls { 182 noTLS = "" 183 } 184 185 var proxyProto string 186 if acceptProxyProtocol { 187 proxyProto = ";proxyproto" 188 } 189 190 args = append(args, fmt.Sprintf("-f127.0.0.1,%v%v%v", serverPort, noTLS, proxyProto), b, 191 "--errorlog-file="+logDir+"/log.txt", "-LINFO") 192 193 if opts.quic { 194 args = append(args, 195 fmt.Sprintf("-f127.0.0.1,%v;quic", serverPort), 196 "--no-quic-bpf") 197 } 198 199 authority := fmt.Sprintf("127.0.0.1:%v", opts.connectPort) 200 201 st := &serverTester{ 202 cmd: exec.Command(serverBin, args...), 203 t: t, 204 ts: ts, 205 url: fmt.Sprintf("%v://%v", scheme, authority), 206 frontendHost: fmt.Sprintf("127.0.0.1:%v", serverPort), 207 backendHost: backendURL.Host, 208 nextStreamID: 1, 209 authority: 
authority, 210 frCh: make(chan http2.Frame), 211 errCh: make(chan error), 212 } 213 214 st.cmd.Stdout = os.Stdout 215 st.cmd.Stderr = os.Stderr 216 217 if err := st.cmd.Start(); err != nil { 218 st.t.Fatalf("Error starting %v: %v", serverBin, err) 219 } 220 221 retry := 0 222 for { 223 time.Sleep(50 * time.Millisecond) 224 225 conn, err := net.Dial("tcp", authority) 226 if err == nil && opts.tls { 227 if len(opts.tcpData) > 0 { 228 if _, err := conn.Write(opts.tcpData); err != nil { 229 st.Close() 230 st.t.Fatal("Error writing TCP data") 231 } 232 } 233 234 var tlsConfig *tls.Config 235 if opts.tlsConfig == nil { 236 tlsConfig = new(tls.Config) 237 } else { 238 tlsConfig = opts.tlsConfig.Clone() 239 } 240 tlsConfig.InsecureSkipVerify = true 241 if alpnH1 { 242 tlsConfig.NextProtos = []string{"http/1.1"} 243 } else { 244 tlsConfig.NextProtos = []string{"h2"} 245 } 246 tlsConn := tls.Client(conn, tlsConfig) 247 err = tlsConn.Handshake() 248 if err == nil { 249 conn = tlsConn 250 } 251 } 252 if err != nil { 253 retry++ 254 if retry >= 100 { 255 st.Close() 256 st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid") 257 } 258 continue 259 } 260 st.conn = conn 261 break 262 } 263 264 st.fr = http2.NewFramer(st.conn, st.conn) 265 st.enc = hpack.NewEncoder(&st.headerBlkBuf) 266 st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) { 267 st.header.Add(f.Name, f.Value) 268 }) 269 270 return st 271} 272 273func (st *serverTester) Close() { 274 if st.conn != nil { 275 st.conn.Close() 276 } 277 if st.cmd != nil { 278 done := make(chan struct{}) 279 go func() { 280 if err := st.cmd.Wait(); err != nil { 281 st.t.Errorf("Error st.cmd.Wait() = %v", err) 282 } 283 close(done) 284 }() 285 286 if err := st.cmd.Process.Signal(syscall.SIGQUIT); err != nil { 287 st.t.Errorf("Error st.cmd.Process.Signal() = %v", err) 288 } 289 290 select { 291 case <-done: 292 case <-time.After(10 * time.Second): 293 if err := st.cmd.Process.Kill(); err != 
nil { 294 st.t.Errorf("Error st.cmd.Process.Kill() = %v", err) 295 } 296 <-done 297 } 298 } 299 if st.ts != nil { 300 st.ts.Close() 301 } 302} 303 304func (st *serverTester) readFrame() (http2.Frame, error) { 305 go func() { 306 f, err := st.fr.ReadFrame() 307 if err != nil { 308 st.errCh <- err 309 return 310 } 311 st.frCh <- f 312 }() 313 314 select { 315 case f := <-st.frCh: 316 return f, nil 317 case err := <-st.errCh: 318 return nil, err 319 case <-time.After(5 * time.Second): 320 return nil, errors.New("timeout waiting for frame") 321 } 322} 323 324type requestParam struct { 325 name string // name for this request to identify the request in log easily 326 streamID uint32 // stream ID, automatically assigned if 0 327 method string // method, defaults to GET 328 scheme string // scheme, defaults to http 329 authority string // authority, defaults to backend server address 330 path string // path, defaults to / 331 header []hpack.HeaderField // additional request header fields 332 body []byte // request body 333 trailer []hpack.HeaderField // trailer part 334 httpUpgrade bool // true if upgraded to HTTP/2 through HTTP Upgrade 335 noEndStream bool // true if END_STREAM should not be sent 336} 337 338// wrapper for request body to set trailer part 339type chunkedBodyReader struct { 340 trailer []hpack.HeaderField 341 trailerWritten bool 342 body io.Reader 343 req *http.Request 344} 345 346func (cbr *chunkedBodyReader) Read(p []byte) (n int, err error) { 347 // document says that we have to set http.Request.Trailer 348 // after request was sent and before body returns EOF. 
349 if !cbr.trailerWritten { 350 cbr.trailerWritten = true 351 for _, h := range cbr.trailer { 352 cbr.req.Trailer.Set(h.Name, h.Value) 353 } 354 } 355 return cbr.body.Read(p) 356} 357 358func (st *serverTester) websocket(rp requestParam) *serverResponse { 359 urlstring := st.url + "/echo" 360 361 config, err := websocket.NewConfig(urlstring, st.url) 362 if err != nil { 363 st.t.Fatalf("websocket.NewConfig(%q, %q) returned error: %v", urlstring, st.url, err) 364 } 365 366 config.Header.Add("Test-Case", rp.name) 367 for _, h := range rp.header { 368 config.Header.Add(h.Name, h.Value) 369 } 370 371 ws, err := websocket.NewClient(config, st.conn) 372 if err != nil { 373 st.t.Fatalf("Error creating websocket client: %v", err) 374 } 375 376 if _, err := ws.Write(rp.body); err != nil { 377 st.t.Fatalf("ws.Write() returned error: %v", err) 378 } 379 380 msg := make([]byte, 1024) 381 var n int 382 if n, err = ws.Read(msg); err != nil { 383 st.t.Fatalf("ws.Read() returned error: %v", err) 384 } 385 386 res := &serverResponse{ 387 body: msg[:n], 388 } 389 390 return res 391} 392 393func (st *serverTester) http3(rp requestParam) (*serverResponse, error) { 394 rt := &http3.RoundTripper{ 395 TLSClientConfig: &tls.Config{ 396 InsecureSkipVerify: true, 397 }, 398 } 399 400 defer rt.Close() 401 402 c := &http.Client{ 403 Transport: rt, 404 } 405 406 method := "GET" 407 if rp.method != "" { 408 method = rp.method 409 } 410 411 var body io.Reader 412 413 if rp.body != nil { 414 body = bytes.NewBuffer(rp.body) 415 } 416 417 reqURL := st.url 418 419 if rp.path != "" { 420 u, err := url.Parse(st.url) 421 if err != nil { 422 st.t.Fatalf("Error parsing URL from st.url %v: %v", st.url, err) 423 } 424 u.Path = "" 425 u.RawQuery = "" 426 reqURL = u.String() + rp.path 427 } 428 429 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 430 defer cancel() 431 432 req, err := http.NewRequestWithContext(ctx, method, reqURL, body) 433 if err != nil { 434 return nil, err 435 } 
436 437 for _, h := range rp.header { 438 req.Header.Add(h.Name, h.Value) 439 } 440 441 req.Header.Add("Test-Case", rp.name) 442 443 // TODO http3 package does not support trailer at the time of 444 // this writing. 445 446 resp, err := c.Do(req) 447 if err != nil { 448 return nil, err 449 } 450 451 defer resp.Body.Close() 452 453 respBody, err := io.ReadAll(resp.Body) 454 if err != nil { 455 return nil, err 456 } 457 458 res := &serverResponse{ 459 status: resp.StatusCode, 460 header: resp.Header, 461 body: respBody, 462 connClose: resp.Close, 463 } 464 465 return res, nil 466} 467 468func (st *serverTester) http1(rp requestParam) (*serverResponse, error) { 469 method := "GET" 470 if rp.method != "" { 471 method = rp.method 472 } 473 474 var body io.Reader 475 var cbr *chunkedBodyReader 476 if rp.body != nil { 477 body = bytes.NewBuffer(rp.body) 478 if len(rp.trailer) != 0 { 479 cbr = &chunkedBodyReader{ 480 trailer: rp.trailer, 481 body: body, 482 } 483 body = cbr 484 } 485 } 486 487 reqURL := st.url 488 489 if rp.path != "" { 490 u, err := url.Parse(st.url) 491 if err != nil { 492 st.t.Fatalf("Error parsing URL from st.url %v: %v", st.url, err) 493 } 494 u.Path = "" 495 u.RawQuery = "" 496 reqURL = u.String() + rp.path 497 } 498 499 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 500 defer cancel() 501 502 req, err := http.NewRequestWithContext(ctx, method, reqURL, body) 503 if err != nil { 504 return nil, err 505 } 506 for _, h := range rp.header { 507 req.Header.Add(h.Name, h.Value) 508 } 509 req.Header.Add("Test-Case", rp.name) 510 if cbr != nil { 511 cbr.req = req 512 // this makes request use chunked encoding 513 req.ContentLength = -1 514 req.Trailer = make(http.Header) 515 for _, h := range cbr.trailer { 516 req.Trailer.Set(h.Name, "") 517 } 518 } 519 if err := req.Write(st.conn); err != nil { 520 return nil, err 521 } 522 resp, err := http.ReadResponse(bufio.NewReader(st.conn), req) 523 if err != nil { 524 return nil, err 525 } 
526 respBody, err := io.ReadAll(resp.Body) 527 if err != nil { 528 return nil, err 529 } 530 resp.Body.Close() 531 532 res := &serverResponse{ 533 status: resp.StatusCode, 534 header: resp.Header, 535 body: respBody, 536 connClose: resp.Close, 537 } 538 539 return res, nil 540} 541 542func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { 543 st.headerBlkBuf.Reset() 544 st.header = make(http.Header) 545 546 var id uint32 547 if rp.streamID != 0 { 548 id = rp.streamID 549 if id >= st.nextStreamID && id%2 == 1 { 550 st.nextStreamID = id + 2 551 } 552 } else { 553 id = st.nextStreamID 554 st.nextStreamID += 2 555 } 556 557 if !st.h2PrefaceSent { 558 st.h2PrefaceSent = true 559 fmt.Fprint(st.conn, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") 560 if err := st.fr.WriteSettings(); err != nil { 561 return nil, err 562 } 563 } 564 565 res := &serverResponse{ 566 streamID: id, 567 } 568 569 streams := make(map[uint32]*serverResponse) 570 streams[id] = res 571 572 if !rp.httpUpgrade { 573 method := "GET" 574 if rp.method != "" { 575 method = rp.method 576 } 577 _ = st.enc.WriteField(pair(":method", method)) 578 579 scheme := "http" 580 if rp.scheme != "" { 581 scheme = rp.scheme 582 } 583 _ = st.enc.WriteField(pair(":scheme", scheme)) 584 585 authority := st.authority 586 if rp.authority != "" { 587 authority = rp.authority 588 } 589 _ = st.enc.WriteField(pair(":authority", authority)) 590 591 path := "/" 592 if rp.path != "" { 593 path = rp.path 594 } 595 _ = st.enc.WriteField(pair(":path", path)) 596 597 _ = st.enc.WriteField(pair("test-case", rp.name)) 598 599 for _, h := range rp.header { 600 _ = st.enc.WriteField(h) 601 } 602 603 err := st.fr.WriteHeaders(http2.HeadersFrameParam{ 604 StreamID: id, 605 EndStream: len(rp.body) == 0 && len(rp.trailer) == 0 && !rp.noEndStream, 606 EndHeaders: true, 607 BlockFragment: st.headerBlkBuf.Bytes(), 608 }) 609 if err != nil { 610 return nil, err 611 } 612 613 if len(rp.body) != 0 { 614 // TODO we assume rp.body fits in 1 
frame 615 if err := st.fr.WriteData(id, len(rp.trailer) == 0 && !rp.noEndStream, rp.body); err != nil { 616 return nil, err 617 } 618 } 619 620 if len(rp.trailer) != 0 { 621 st.headerBlkBuf.Reset() 622 for _, h := range rp.trailer { 623 _ = st.enc.WriteField(h) 624 } 625 err := st.fr.WriteHeaders(http2.HeadersFrameParam{ 626 StreamID: id, 627 EndStream: true, 628 EndHeaders: true, 629 BlockFragment: st.headerBlkBuf.Bytes(), 630 }) 631 if err != nil { 632 return nil, err 633 } 634 } 635 } 636loop: 637 for { 638 fr, err := st.readFrame() 639 if err != nil { 640 return res, err 641 } 642 switch f := fr.(type) { 643 case *http2.HeadersFrame: 644 _, err := st.dec.Write(f.HeaderBlockFragment()) 645 if err != nil { 646 return res, err 647 } 648 sr, ok := streams[f.FrameHeader.StreamID] 649 if !ok { 650 st.header = make(http.Header) 651 break 652 } 653 sr.header = cloneHeader(st.header) 654 var status int 655 status, err = strconv.Atoi(sr.header.Get(":status")) 656 if err != nil { 657 return res, fmt.Errorf("Error parsing status code: %w", err) 658 } 659 sr.status = status 660 if f.StreamEnded() { 661 if streamEnded(res, streams, sr) { 662 break loop 663 } 664 } 665 case *http2.PushPromiseFrame: 666 _, err := st.dec.Write(f.HeaderBlockFragment()) 667 if err != nil { 668 return res, err 669 } 670 sr := &serverResponse{ 671 streamID: f.PromiseID, 672 reqHeader: cloneHeader(st.header), 673 } 674 streams[sr.streamID] = sr 675 case *http2.DataFrame: 676 sr, ok := streams[f.FrameHeader.StreamID] 677 if !ok { 678 break 679 } 680 sr.body = append(sr.body, f.Data()...) 
681 if f.StreamEnded() { 682 if streamEnded(res, streams, sr) { 683 break loop 684 } 685 } 686 case *http2.RSTStreamFrame: 687 sr, ok := streams[f.FrameHeader.StreamID] 688 if !ok { 689 break 690 } 691 sr.errCode = f.ErrCode 692 if streamEnded(res, streams, sr) { 693 break loop 694 } 695 case *http2.GoAwayFrame: 696 if f.ErrCode == http2.ErrCodeNo { 697 break 698 } 699 res.errCode = f.ErrCode 700 res.connErr = true 701 break loop 702 case *http2.SettingsFrame: 703 if f.IsAck() { 704 break 705 } 706 if err := st.fr.WriteSettingsAck(); err != nil { 707 return res, err 708 } 709 } 710 } 711 sort.Sort(ByStreamID(res.pushResponse)) 712 return res, nil 713} 714 715func streamEnded(mainSr *serverResponse, streams map[uint32]*serverResponse, sr *serverResponse) bool { 716 delete(streams, sr.streamID) 717 if mainSr.streamID != sr.streamID { 718 mainSr.pushResponse = append(mainSr.pushResponse, sr) 719 } 720 return len(streams) == 0 721} 722 723type serverResponse struct { 724 status int // HTTP status code 725 header http.Header // response header fields 726 body []byte // response body 727 streamID uint32 // stream ID in HTTP/2 728 errCode http2.ErrCode // error code received in HTTP/2 RST_STREAM or GOAWAY 729 connErr bool // true if HTTP/2 connection error 730 connClose bool // Connection: close is included in response header in HTTP/1 test 731 reqHeader http.Header // http request header, currently only sotres pushed request header 732 pushResponse []*serverResponse // pushed response 733} 734 735type ByStreamID []*serverResponse 736 737func (b ByStreamID) Len() int { 738 return len(b) 739} 740 741func (b ByStreamID) Swap(i, j int) { 742 b[i], b[j] = b[j], b[i] 743} 744 745func (b ByStreamID) Less(i, j int) bool { 746 return b[i].streamID < b[j].streamID 747} 748 749func cloneHeader(h http.Header) http.Header { 750 h2 := make(http.Header, len(h)) 751 for k, vv := range h { 752 vv2 := make([]string, len(vv)) 753 copy(vv2, vv) 754 h2[k] = vv2 755 } 756 return h2 757} 758 
// noopHandler is the default backend handler used by newServerTester.
// It drains the request body and writes nothing else.  It replies with
// 500 only if reading the body fails.
func noopHandler(w http.ResponseWriter, r *http.Request) {
	if _, err := io.ReadAll(r.Body); err != nil {
		http.Error(w, fmt.Sprintf("Error io.ReadAll() = %v", err), http.StatusInternalServerError)
	}
}

// APIResponse models a JSON body with "status", "code" and "data"
// members.  Empty members are omitted when encoding.
type APIResponse struct {
	Status string                 `json:"status,omitempty"`
	Code   int                    `json:"code,omitempty"`
	Data   map[string]interface{} `json:"data,omitempty"`
}

// proxyProtocolV2 describes a PROXY protocol version 2 header written
// by writeProxyProtocolV2.
type proxyProtocolV2 struct {
	command proxyProtocolV2Command
	// sourceAddress is a *net.TCPAddr or *net.UnixAddr.  Any other
	// value, including nil, is encoded as UNSPEC.
	sourceAddress net.Addr
	// destinationAddress must have the same concrete type as
	// sourceAddress for TCP and Unix sources.
	destinationAddress net.Addr
	// additionalData is written verbatim after the addresses and is
	// counted in the length field.
	additionalData []byte
}

// proxyProtocolV2Command is the command stored in the low nibble of the
// version/command byte.
type proxyProtocolV2Command int

const (
	proxyProtocolV2CommandLocal proxyProtocolV2Command = 0x0
	proxyProtocolV2CommandProxy proxyProtocolV2Command = 0x1
)

// proxyProtocolV2Family is the address family stored in the high
// nibble of the family/protocol byte.
type proxyProtocolV2Family int

const (
	proxyProtocolV2FamilyUnspec proxyProtocolV2Family = 0x0
	proxyProtocolV2FamilyInet   proxyProtocolV2Family = 0x1
	proxyProtocolV2FamilyInet6  proxyProtocolV2Family = 0x2
	proxyProtocolV2FamilyUnix   proxyProtocolV2Family = 0x3
)

// proxyProtocolV2Protocol is the transport protocol stored in the low
// nibble of the family/protocol byte.
type proxyProtocolV2Protocol int

const (
	proxyProtocolV2ProtocolUnspec proxyProtocolV2Protocol = 0x0
	proxyProtocolV2ProtocolStream proxyProtocolV2Protocol = 0x1
	proxyProtocolV2ProtocolDgram  proxyProtocolV2Protocol = 0x2
)

// writeProxyProtocolV2 writes hdr to w as a binary PROXY protocol
// version 2 header.  The header consists of, in order:
//   - the 12-byte signature
//   - the version/command byte
//   - the family/protocol byte
//   - the big-endian length of the remainder
//   - the addresses
//   - hdr.additionalData
//
// Address encoding:
//   - TCP addresses are encoded as INET when the IPs are 4 bytes long and
//     as INET6 otherwise.  IPv4 addresses must therefore be passed in
//     their 4-byte form (e.g. net.IP.To4).
//   - Unix paths are NUL-padded to 108 bytes each.  Net "unix" maps to
//     STREAM and "unixgram" to DGRAM; any other network maps to UNSPEC.
//   - Any other address type, including nil, is encoded as UNSPEC with
//     no address block.
//
// It panics on inconsistent source/destination addresses or on a Unix
// path longer than 108 bytes.
func writeProxyProtocolV2(w io.Writer, hdr proxyProtocolV2) error {
	// Protocol signature.
	if _, err := w.Write([]byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}); err != nil {
		return err
	}
	// Version 2 in the high nibble, command in the low nibble.
	if _, err := w.Write([]byte{byte(0x20 | hdr.command)}); err != nil {
		return err
	}

	switch srcAddr := hdr.sourceAddress.(type) {
	case *net.TCPAddr:
		dstAddr := hdr.destinationAddress.(*net.TCPAddr)
		if len(srcAddr.IP) != len(dstAddr.IP) {
			panic("len(srcAddr.IP) != len(dstAddr.IP)")
		}
		var fam byte
		if len(srcAddr.IP) == 4 {
			fam = byte(proxyProtocolV2FamilyInet << 4)
		} else {
			fam = byte(proxyProtocolV2FamilyInet6 << 4)
		}
		fam |= byte(proxyProtocolV2ProtocolStream)
		if _, err := w.Write([]byte{fam}); err != nil {
			return err
		}
		// Two addresses plus two 16-bit ports.
		length := uint16(len(srcAddr.IP)*2 + 4 + len(hdr.additionalData))
		if err := binary.Write(w, binary.BigEndian, length); err != nil {
			return err
		}
		if _, err := w.Write(srcAddr.IP); err != nil {
			return err
		}
		if _, err := w.Write(dstAddr.IP); err != nil {
			return err
		}
		if err := binary.Write(w, binary.BigEndian, uint16(srcAddr.Port)); err != nil {
			return err
		}
		if err := binary.Write(w, binary.BigEndian, uint16(dstAddr.Port)); err != nil {
			return err
		}
	case *net.UnixAddr:
		dstAddr := hdr.destinationAddress.(*net.UnixAddr)
		if len(srcAddr.Name) > 108 {
			panic("too long Unix source address")
		}
		if len(dstAddr.Name) > 108 {
			panic("too long Unix destination address")
		}
		fam := byte(proxyProtocolV2FamilyUnix << 4)
		switch srcAddr.Net {
		case "unix":
			fam |= byte(proxyProtocolV2ProtocolStream)
		case "unixgram", "unixdgram":
			// "unixgram" is Go's network name for datagram Unix
			// sockets; "unixdgram" is still accepted for callers that
			// relied on the previous spelling.
			fam |= byte(proxyProtocolV2ProtocolDgram)
		default:
			fam |= byte(proxyProtocolV2ProtocolUnspec)
		}
		if _, err := w.Write([]byte{fam}); err != nil {
			return err
		}
		// Two fixed-size 108-byte path fields.
		length := uint16(216 + len(hdr.additionalData))
		if err := binary.Write(w, binary.BigEndian, length); err != nil {
			return err
		}
		zeros := make([]byte, 108)
		if _, err := w.Write([]byte(srcAddr.Name)); err != nil {
			return err
		}
		if _, err := w.Write(zeros[:108-len(srcAddr.Name)]); err != nil {
			return err
		}
		if _, err := w.Write([]byte(dstAddr.Name)); err != nil {
			return err
		}
		if _, err := w.Write(zeros[:108-len(dstAddr.Name)]); err != nil {
			return err
		}
	default:
		fam := byte(proxyProtocolV2FamilyUnspec<<4) | byte(proxyProtocolV2ProtocolUnspec)
		if _, err := w.Write([]byte{fam}); err != nil {
			return err
		}
		length := uint16(len(hdr.additionalData))
		if err := binary.Write(w, binary.BigEndian, length); err != nil {
			return err
		}
	}

	if _, err := w.Write(hdr.additionalData); err != nil {
		return err
	}

	return nil
}