mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
implement http3.RoundTripper.CloseIdleConnections (#3820)
* implement CloseIdleConnections * nit Co-authored-by: Marten Seemann <martenseemann@gmail.com> --------- Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
parent
e9fea08613
commit
cec79d338c
2 changed files with 74 additions and 6 deletions
|
@ -10,6 +10,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
|
@ -25,6 +26,11 @@ type roundTripCloser interface {
|
|||
io.Closer
|
||||
}
|
||||
|
||||
type roundTripCloserWithCount struct {
|
||||
roundTripCloser
|
||||
useCount atomic.Int64
|
||||
}
|
||||
|
||||
// RoundTripper implements the http.RoundTripper interface
|
||||
type RoundTripper struct {
|
||||
mutex sync.Mutex
|
||||
|
@ -82,7 +88,7 @@ type RoundTripper struct {
|
|||
MaxResponseHeaderBytes int64
|
||||
|
||||
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
|
||||
clients map[string]roundTripCloser
|
||||
clients map[string]*roundTripCloserWithCount
|
||||
udpConn *net.UDPConn
|
||||
}
|
||||
|
||||
|
@ -143,6 +149,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cl.useCount.Add(-1)
|
||||
rsp, err := cl.RoundTripOpt(req, opt)
|
||||
if err != nil {
|
||||
r.removeClient(hostname)
|
||||
|
@ -160,12 +167,12 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) {
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if r.clients == nil {
|
||||
r.clients = make(map[string]roundTripCloser)
|
||||
r.clients = make(map[string]*roundTripCloserWithCount)
|
||||
}
|
||||
|
||||
client, ok := r.clients[hostname]
|
||||
|
@ -188,7 +195,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri
|
|||
}
|
||||
dial = r.makeDialer()
|
||||
}
|
||||
client, err = newCl(
|
||||
c, err := newCl(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{
|
||||
|
@ -204,10 +211,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri
|
|||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
client = &roundTripCloserWithCount{roundTripCloser: c}
|
||||
r.clients[hostname] = client
|
||||
} else if client.HandshakeComplete() {
|
||||
isReused = true
|
||||
}
|
||||
client.useCount.Add(1)
|
||||
return client, isReused, nil
|
||||
}
|
||||
|
||||
|
@ -276,3 +285,14 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf
|
|||
return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RoundTripper) CloseIdleConnections() {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for hostname, client := range r.clients {
|
||||
if client.useCount.Load() == 0 {
|
||||
client.Close()
|
||||
delete(r.clients, hostname)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
@ -304,10 +305,10 @@ var _ = Describe("RoundTripper", func() {
|
|||
|
||||
Context("closing", func() {
|
||||
It("closes", func() {
|
||||
rt.clients = make(map[string]roundTripCloser)
|
||||
rt.clients = make(map[string]*roundTripCloserWithCount)
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().Close()
|
||||
rt.clients["foo.bar"] = cl
|
||||
rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}}
|
||||
err := rt.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
|
@ -319,6 +320,53 @@ var _ = Describe("RoundTripper", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
})
|
||||
|
||||
It("closes idle connections", func() {
|
||||
Expect(len(rt.clients)).To(Equal(0))
|
||||
req1, err := http.NewRequest("GET", "https://site1.com", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req2, err := http.NewRequest("GET", "https://site2.com", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(req1.Host).ToNot(Equal(req2.Host))
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
req1 = req1.WithContext(ctx1)
|
||||
req2 = req2.WithContext(ctx2)
|
||||
roundTripCalled := make(chan struct{})
|
||||
reqFinished := make(chan struct{})
|
||||
rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().Close()
|
||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||
roundTripCalled <- struct{}{}
|
||||
<-r.Context().Done()
|
||||
return nil, nil
|
||||
})
|
||||
return cl, nil
|
||||
}
|
||||
go func() {
|
||||
rt.RoundTrip(req1)
|
||||
reqFinished <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
rt.RoundTrip(req2)
|
||||
reqFinished <- struct{}{}
|
||||
}()
|
||||
<-roundTripCalled
|
||||
<-roundTripCalled
|
||||
// Both two requests are started.
|
||||
Expect(len(rt.clients)).To(Equal(2))
|
||||
cancel1()
|
||||
<-reqFinished
|
||||
// req1 is finished
|
||||
rt.CloseIdleConnections()
|
||||
Expect(len(rt.clients)).To(Equal(1))
|
||||
cancel2()
|
||||
<-reqFinished
|
||||
// all requests are finished
|
||||
rt.CloseIdleConnections()
|
||||
Expect(len(rt.clients)).To(Equal(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("reusing udpconn", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue