• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// White-box tests for transport.go (in package http instead of http_test).
6
7package http
8
9import (
10	"bytes"
11	"context"
12	"crypto/tls"
13	"errors"
14	"io"
15	"net"
16	"net/http/internal/testcert"
17	"strings"
18	"testing"
19)
20
21// Issue 15446: incorrect wrapping of errors when server closes an idle connection.
22func TestTransportPersistConnReadLoopEOF(t *testing.T) {
23	ln := newLocalListener(t)
24	defer ln.Close()
25
26	connc := make(chan net.Conn, 1)
27	go func() {
28		defer close(connc)
29		c, err := ln.Accept()
30		if err != nil {
31			t.Error(err)
32			return
33		}
34		connc <- c
35	}()
36
37	tr := new(Transport)
38	req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
39	req = req.WithT(t)
40	ctx, cancel := context.WithCancelCause(context.Background())
41	treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel}
42	cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
43	pc, err := tr.getConn(treq, cm)
44	if err != nil {
45		t.Fatal(err)
46	}
47	defer pc.close(errors.New("test over"))
48
49	conn := <-connc
50	if conn == nil {
51		// Already called t.Error in the accept goroutine.
52		return
53	}
54	conn.Close() // simulate the server hanging up on the client
55
56	_, err = pc.roundTrip(treq)
57	if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
58		t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err)
59	}
60
61	<-pc.closech
62	err = pc.closed
63	if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
64		t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError, or nothingWrittenError", err, err)
65	}
66}
67
68func isNothingWrittenError(err error) bool {
69	_, ok := err.(nothingWrittenError)
70	return ok
71}
72
73func isTransportReadFromServerError(err error) bool {
74	_, ok := err.(transportReadFromServerError)
75	return ok
76}
77
// newLocalListener returns a TCP listener on an ephemeral loopback port,
// trying IPv4 (127.0.0.1) first and falling back to IPv6 ([::1]). If neither
// can be bound, it fails the test, reporting both errors.
func newLocalListener(t *testing.T) net.Listener {
	t.Helper()
	ln, err4 := net.Listen("tcp", "127.0.0.1:0")
	if err4 == nil {
		return ln
	}
	ln, err6 := net.Listen("tcp6", "[::1]:0")
	if err6 != nil {
		// Keep the IPv4 error too; it is usually the more informative one.
		t.Fatalf("cannot listen on a loopback address: %v; %v", err4, err6)
	}
	return ln
}
88
89func dummyRequest(method string) *Request {
90	req, err := NewRequest(method, "http://fake.tld/", nil)
91	if err != nil {
92		panic(err)
93	}
94	return req
95}
96func dummyRequestWithBody(method string) *Request {
97	req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
98	if err != nil {
99		panic(err)
100	}
101	return req
102}
103
104func dummyRequestWithBodyNoGetBody(method string) *Request {
105	req := dummyRequestWithBody(method)
106	req.GetBody = nil
107	return req
108}
109
// issue22091Error stands in for golang.org/x/net/http2.ErrNoCachedConn, an
// error defined outside this package. Besides implementing error, it carries
// the IsHTTP2NoCachedConnError marker method. TestTransportShouldRetryRequest
// expects shouldRetryRequest to treat it as retryable.
type issue22091Error struct{}

// Error implements the error interface.
func (issue22091Error) Error() string { return "issue22091Error" }

// IsHTTP2NoCachedConnError is a marker method with no behavior of its own.
func (issue22091Error) IsHTTP2NoCachedConnError() {}
115
116func TestTransportShouldRetryRequest(t *testing.T) {
117	tests := []struct {
118		pc  *persistConn
119		req *Request
120
121		err  error
122		want bool
123	}{
124		0: {
125			pc:   &persistConn{reused: false},
126			req:  dummyRequest("POST"),
127			err:  nothingWrittenError{},
128			want: false,
129		},
130		1: {
131			pc:   &persistConn{reused: true},
132			req:  dummyRequest("POST"),
133			err:  nothingWrittenError{},
134			want: true,
135		},
136		2: {
137			pc:   &persistConn{reused: true},
138			req:  dummyRequest("POST"),
139			err:  http2ErrNoCachedConn,
140			want: true,
141		},
142		3: {
143			pc:   nil,
144			req:  nil,
145			err:  issue22091Error{}, // like an external http2ErrNoCachedConn
146			want: true,
147		},
148		4: {
149			pc:   &persistConn{reused: true},
150			req:  dummyRequest("POST"),
151			err:  errMissingHost,
152			want: false,
153		},
154		5: {
155			pc:   &persistConn{reused: true},
156			req:  dummyRequest("POST"),
157			err:  transportReadFromServerError{},
158			want: false,
159		},
160		6: {
161			pc:   &persistConn{reused: true},
162			req:  dummyRequest("GET"),
163			err:  transportReadFromServerError{},
164			want: true,
165		},
166		7: {
167			pc:   &persistConn{reused: true},
168			req:  dummyRequest("GET"),
169			err:  errServerClosedIdle,
170			want: true,
171		},
172		8: {
173			pc:   &persistConn{reused: true},
174			req:  dummyRequestWithBody("POST"),
175			err:  nothingWrittenError{},
176			want: true,
177		},
178		9: {
179			pc:   &persistConn{reused: true},
180			req:  dummyRequestWithBodyNoGetBody("POST"),
181			err:  nothingWrittenError{},
182			want: false,
183		},
184	}
185	for i, tt := range tests {
186		got := tt.pc.shouldRetryRequest(tt.req, tt.err)
187		if got != tt.want {
188			t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
189		}
190	}
191}
192
193type roundTripFunc func(r *Request) (*Response, error)
194
195func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
196	return f(r)
197}
198
199// Issue 25009
200func TestTransportBodyAltRewind(t *testing.T) {
201	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
202	if err != nil {
203		t.Fatal(err)
204	}
205	ln := newLocalListener(t)
206	defer ln.Close()
207
208	go func() {
209		tln := tls.NewListener(ln, &tls.Config{
210			NextProtos:   []string{"foo"},
211			Certificates: []tls.Certificate{cert},
212		})
213		for i := 0; i < 2; i++ {
214			sc, err := tln.Accept()
215			if err != nil {
216				t.Error(err)
217				return
218			}
219			if err := sc.(*tls.Conn).Handshake(); err != nil {
220				t.Error(err)
221				return
222			}
223			sc.Close()
224		}
225	}()
226
227	addr := ln.Addr().String()
228	req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
229	roundTripped := false
230	tr := &Transport{
231		DisableKeepAlives: true,
232		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
233			"foo": func(authority string, c *tls.Conn) RoundTripper {
234				return roundTripFunc(func(r *Request) (*Response, error) {
235					n, _ := io.Copy(io.Discard, r.Body)
236					if n == 0 {
237						t.Error("body length is zero")
238					}
239					if roundTripped {
240						return &Response{
241							Body:       NoBody,
242							StatusCode: 200,
243						}, nil
244					}
245					roundTripped = true
246					return nil, http2noCachedConnError{}
247				})
248			},
249		},
250		DialTLS: func(_, _ string) (net.Conn, error) {
251			tc, err := tls.Dial("tcp", addr, &tls.Config{
252				InsecureSkipVerify: true,
253				NextProtos:         []string{"foo"},
254			})
255			if err != nil {
256				return nil, err
257			}
258			if err := tc.Handshake(); err != nil {
259				return nil, err
260			}
261			return tc, nil
262		},
263	}
264	c := &Client{Transport: tr}
265	_, err = c.Do(req)
266	if err != nil {
267		t.Error(err)
268	}
269}
270