• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package oauth2
6
7import (
8	"errors"
9	"io"
10	"net/http"
11	"sync"
12)
13
14// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
15// wrapping a base RoundTripper and adding an Authorization header
16// with a token from the supplied Sources.
17//
18// Transport is a low-level mechanism. Most code will use the
19// higher-level Config.Client method instead.
20type Transport struct {
21	// Source supplies the token to add to outgoing requests'
22	// Authorization headers.
23	Source TokenSource
24
25	// Base is the base RoundTripper used to make HTTP requests.
26	// If nil, http.DefaultTransport is used.
27	Base http.RoundTripper
28
29	mu     sync.Mutex                      // guards modReq
30	modReq map[*http.Request]*http.Request // original -> modified
31}
32
33// RoundTrip authorizes and authenticates the request with an
34// access token. If no token exists or token is expired,
35// tries to refresh/fetch a new token.
36func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
37	if t.Source == nil {
38		return nil, errors.New("oauth2: Transport's Source is nil")
39	}
40	token, err := t.Source.Token()
41	if err != nil {
42		return nil, err
43	}
44
45	req2 := cloneRequest(req) // per RoundTripper contract
46	token.SetAuthHeader(req2)
47	t.setModReq(req, req2)
48	res, err := t.base().RoundTrip(req2)
49	if err != nil {
50		t.setModReq(req, nil)
51		return nil, err
52	}
53	res.Body = &onEOFReader{
54		rc: res.Body,
55		fn: func() { t.setModReq(req, nil) },
56	}
57	return res, nil
58}
59
60// CancelRequest cancels an in-flight request by closing its connection.
61func (t *Transport) CancelRequest(req *http.Request) {
62	type canceler interface {
63		CancelRequest(*http.Request)
64	}
65	if cr, ok := t.base().(canceler); ok {
66		t.mu.Lock()
67		modReq := t.modReq[req]
68		delete(t.modReq, req)
69		t.mu.Unlock()
70		cr.CancelRequest(modReq)
71	}
72}
73
74func (t *Transport) base() http.RoundTripper {
75	if t.Base != nil {
76		return t.Base
77	}
78	return http.DefaultTransport
79}
80
81func (t *Transport) setModReq(orig, mod *http.Request) {
82	t.mu.Lock()
83	defer t.mu.Unlock()
84	if t.modReq == nil {
85		t.modReq = make(map[*http.Request]*http.Request)
86	}
87	if mod == nil {
88		delete(t.modReq, orig)
89	} else {
90		t.modReq[orig] = mod
91	}
92}
93
// cloneRequest returns a copy of r whose headers can be changed without
// affecting r.
//
// Every field of the struct is copied shallowly, so the URL, Body, and
// other reference-typed fields are shared with r. The exception is Header:
// it gets a fresh map, and each value slice is copied too. A nil Header on
// r becomes an empty, non-nil Header on the copy.
func cloneRequest(r *http.Request) *http.Request {
	dup := *r

	hdr := make(http.Header, len(r.Header))
	for name, values := range r.Header {
		hdr[name] = append([]string(nil), values...)
	}
	dup.Header = hdr

	return &dup
}
107
// onEOFReader wraps a response body and invokes fn at most once. It fires as
// soon as the wrapped reader reports io.EOF or the body is closed, whichever
// comes first. Transport.RoundTrip uses it to drop its per-request
// bookkeeping once the caller is done with the response.
type onEOFReader struct {
	rc io.ReadCloser
	fn func()

	// once ensures fn runs at most once, even when Close is called from
	// another goroutine while a Read is reaching EOF.
	once sync.Once
}

// Read reads from the wrapped body and runs the callback when the body
// reports io.EOF. n and err are returned unchanged.
func (r *onEOFReader) Read(p []byte) (n int, err error) {
	n, err = r.rc.Read(p)
	// The io.Reader contract returns io.EOF itself (not a wrapped error) at
	// end of stream, so a direct comparison is correct.
	if err == io.EOF {
		r.runFunc()
	}
	return n, err
}

// Close closes the wrapped body and then runs the callback if it has not
// run yet. The wrapped Close error is returned unchanged.
func (r *onEOFReader) Close() error {
	err := r.rc.Close()
	r.runFunc()
	return err
}

// runFunc invokes fn on its first call and does nothing afterwards. It
// clears fn before running it, so the closure and everything it captures
// can be collected even while the body is still referenced. A nil fn is
// allowed.
func (r *onEOFReader) runFunc() {
	r.once.Do(func() {
		if fn := r.fn; fn != nil {
			r.fn = nil
			fn()
		}
	})
}
133