1// Copyright 2016 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// White-box tests for transport.go (in package http instead of http_test). 6 7package http 8 9import ( 10 "bytes" 11 "context" 12 "crypto/tls" 13 "errors" 14 "io" 15 "net" 16 "net/http/internal/testcert" 17 "strings" 18 "testing" 19) 20 21// Issue 15446: incorrect wrapping of errors when server closes an idle connection. 22func TestTransportPersistConnReadLoopEOF(t *testing.T) { 23 ln := newLocalListener(t) 24 defer ln.Close() 25 26 connc := make(chan net.Conn, 1) 27 go func() { 28 defer close(connc) 29 c, err := ln.Accept() 30 if err != nil { 31 t.Error(err) 32 return 33 } 34 connc <- c 35 }() 36 37 tr := new(Transport) 38 req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil) 39 req = req.WithT(t) 40 ctx, cancel := context.WithCancelCause(context.Background()) 41 treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel} 42 cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} 43 pc, err := tr.getConn(treq, cm) 44 if err != nil { 45 t.Fatal(err) 46 } 47 defer pc.close(errors.New("test over")) 48 49 conn := <-connc 50 if conn == nil { 51 // Already called t.Error in the accept goroutine. 52 return 53 } 54 conn.Close() // simulate the server hanging up on the client 55 56 _, err = pc.roundTrip(treq) 57 if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle { 58 t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err) 59 } 60 61 <-pc.closech 62 err = pc.closed 63 if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle { 64 t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError, or nothingWrittenError", err, err) 65 } 66} 67 68func isNothingWrittenError(err error) bool { 69 _, ok := err.(nothingWrittenError) 70 return ok 71} 72 73func isTransportReadFromServerError(err error) bool { 74 _, ok := err.(transportReadFromServerError) 75 return ok 76} 77 78func newLocalListener(t *testing.T) net.Listener { 79 ln, err := net.Listen("tcp", "127.0.0.1:0") 80 if err != nil { 81 ln, err = net.Listen("tcp6", "[::1]:0") 82 } 83 if err != nil { 84 t.Fatal(err) 85 } 86 return ln 87} 88 89func dummyRequest(method string) *Request { 90 req, err := NewRequest(method, "http://fake.tld/", nil) 91 if err != nil { 92 panic(err) 93 } 94 return req 95} 96func dummyRequestWithBody(method string) *Request { 97 req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo")) 98 if err != nil { 99 panic(err) 100 } 101 return req 102} 103 104func dummyRequestWithBodyNoGetBody(method string) *Request { 105 req := dummyRequestWithBody(method) 106 req.GetBody = nil 107 return req 108} 109 110// issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn. 111type issue22091Error struct{} 112 113func (issue22091Error) IsHTTP2NoCachedConnError() {} 114func (issue22091Error) Error() string { return "issue22091Error" } 115 116func TestTransportShouldRetryRequest(t *testing.T) { 117 tests := []struct { 118 pc *persistConn 119 req *Request 120 121 err error 122 want bool 123 }{ 124 0: { 125 pc: &persistConn{reused: false}, 126 req: dummyRequest("POST"), 127 err: nothingWrittenError{}, 128 want: false, 129 }, 130 1: { 131 pc: &persistConn{reused: true}, 132 req: dummyRequest("POST"), 133 err: nothingWrittenError{}, 134 want: true, 135 }, 136 2: { 137 pc: &persistConn{reused: true}, 138 req: dummyRequest("POST"), 139 err: http2ErrNoCachedConn, 140 want: true, 141 }, 142 3: { 143 pc: nil, 144 req: nil, 145 err: issue22091Error{}, // like an external http2ErrNoCachedConn 146 want: true, 147 }, 148 4: { 149 pc: &persistConn{reused: true}, 150 req: dummyRequest("POST"), 151 err: errMissingHost, 152 want: false, 153 }, 154 5: { 155 pc: &persistConn{reused: true}, 156 req: dummyRequest("POST"), 157 err: transportReadFromServerError{}, 158 want: false, 159 }, 160 6: { 161 pc: &persistConn{reused: true}, 162 req: dummyRequest("GET"), 163 err: transportReadFromServerError{}, 164 want: true, 165 }, 166 7: { 167 pc: &persistConn{reused: true}, 168 req: dummyRequest("GET"), 169 err: errServerClosedIdle, 170 want: true, 171 }, 172 8: { 173 pc: &persistConn{reused: true}, 174 req: dummyRequestWithBody("POST"), 175 err: nothingWrittenError{}, 176 want: true, 177 }, 178 9: { 179 pc: &persistConn{reused: true}, 180 req: dummyRequestWithBodyNoGetBody("POST"), 181 err: nothingWrittenError{}, 182 want: false, 183 }, 184 } 185 for i, tt := range tests { 186 got := tt.pc.shouldRetryRequest(tt.req, tt.err) 187 if got != tt.want { 188 t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want) 189 } 190 } 191} 192 193type roundTripFunc func(r *Request) (*Response, error) 194 195func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { 196 return f(r) 197} 198 199// Issue 25009 200func TestTransportBodyAltRewind(t *testing.T) { 201 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) 202 if err != nil { 203 t.Fatal(err) 204 } 205 ln := newLocalListener(t) 206 defer ln.Close() 207 208 go func() { 209 tln := tls.NewListener(ln, &tls.Config{ 210 NextProtos: []string{"foo"}, 211 Certificates: []tls.Certificate{cert}, 212 }) 213 for i := 0; i < 2; i++ { 214 sc, err := tln.Accept() 215 if err != nil { 216 t.Error(err) 217 return 218 } 219 if err := sc.(*tls.Conn).Handshake(); err != nil { 220 t.Error(err) 221 return 222 } 223 sc.Close() 224 } 225 }() 226 227 addr := ln.Addr().String() 228 req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request")) 229 roundTripped := false 230 tr := &Transport{ 231 DisableKeepAlives: true, 232 TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ 233 "foo": func(authority string, c *tls.Conn) RoundTripper { 234 return roundTripFunc(func(r *Request) (*Response, error) { 235 n, _ := io.Copy(io.Discard, r.Body) 236 if n == 0 { 237 t.Error("body length is zero") 238 } 239 if roundTripped { 240 return &Response{ 241 Body: NoBody, 242 StatusCode: 200, 243 }, nil 244 } 245 roundTripped = true 246 return nil, http2noCachedConnError{} 247 }) 248 }, 249 }, 250 DialTLS: func(_, _ string) (net.Conn, error) { 251 tc, err := tls.Dial("tcp", addr, &tls.Config{ 252 InsecureSkipVerify: true, 253 NextProtos: []string{"foo"}, 254 }) 255 if err != nil { 256 return nil, err 257 } 258 if err := tc.Handshake(); err != nil { 259 return nil, err 260 } 261 return tc, nil 262 }, 263 } 264 c := &Client{Transport: tr} 265 _, err = c.Do(req) 266 if err != nil { 267 t.Error(err) 268 } 269} 270