implement http3.RoundTripper.CloseIdleConnections (#3820)

* implement CloseIdleConnections

* nit

Co-authored-by: Marten Seemann <martenseemann@gmail.com>

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
Glonee 2023-05-15 15:12:00 +08:00 committed by GitHub
parent e9fea08613
commit cec79d338c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 6 deletions

View file

@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
@ -25,6 +26,11 @@ type roundTripCloser interface {
io.Closer io.Closer
} }
type roundTripCloserWithCount struct {
roundTripCloser
useCount atomic.Int64
}
// RoundTripper implements the http.RoundTripper interface // RoundTripper implements the http.RoundTripper interface
type RoundTripper struct { type RoundTripper struct {
mutex sync.Mutex mutex sync.Mutex
@ -82,7 +88,7 @@ type RoundTripper struct {
MaxResponseHeaderBytes int64 MaxResponseHeaderBytes int64
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
clients map[string]roundTripCloser clients map[string]*roundTripCloserWithCount
udpConn *net.UDPConn udpConn *net.UDPConn
} }
@ -143,6 +149,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cl.useCount.Add(-1)
rsp, err := cl.RoundTripOpt(req, opt) rsp, err := cl.RoundTripOpt(req, opt)
if err != nil { if err != nil {
r.removeClient(hostname) r.removeClient(hostname)
@ -160,12 +167,12 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{}) return r.RoundTripOpt(req, RoundTripOpt{})
} }
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) { func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.clients == nil { if r.clients == nil {
r.clients = make(map[string]roundTripCloser) r.clients = make(map[string]*roundTripCloserWithCount)
} }
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
@ -188,7 +195,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri
} }
dial = r.makeDialer() dial = r.makeDialer()
} }
client, err = newCl( c, err := newCl(
hostname, hostname,
r.TLSClientConfig, r.TLSClientConfig,
&roundTripperOpts{ &roundTripperOpts{
@ -204,10 +211,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
client = &roundTripCloserWithCount{roundTripCloser: c}
r.clients[hostname] = client r.clients[hostname] = client
} else if client.HandshakeComplete() { } else if client.HandshakeComplete() {
isReused = true isReused = true
} }
client.useCount.Add(1)
return client, isReused, nil return client, isReused, nil
} }
@ -276,3 +285,14 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf
return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg) return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg)
} }
} }
func (r *RoundTripper) CloseIdleConnections() {
r.mutex.Lock()
defer r.mutex.Unlock()
for hostname, client := range r.clients {
if client.useCount.Load() == 0 {
client.Close()
delete(r.clients, hostname)
}
}
}

View file

@ -8,6 +8,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"sync/atomic"
"time" "time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
@ -304,10 +305,10 @@ var _ = Describe("RoundTripper", func() {
Context("closing", func() { Context("closing", func() {
It("closes", func() { It("closes", func() {
rt.clients = make(map[string]roundTripCloser) rt.clients = make(map[string]*roundTripCloserWithCount)
cl := NewMockRoundTripCloser(mockCtrl) cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().Close() cl.EXPECT().Close()
rt.clients["foo.bar"] = cl rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}}
err := rt.Close() err := rt.Close()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero()) Expect(len(rt.clients)).To(BeZero())
@ -319,6 +320,53 @@ var _ = Describe("RoundTripper", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero()) Expect(len(rt.clients)).To(BeZero())
}) })
It("closes idle connections", func() {
Expect(len(rt.clients)).To(Equal(0))
req1, err := http.NewRequest("GET", "https://site1.com", nil)
Expect(err).ToNot(HaveOccurred())
req2, err := http.NewRequest("GET", "https://site2.com", nil)
Expect(err).ToNot(HaveOccurred())
Expect(req1.Host).ToNot(Equal(req2.Host))
ctx1, cancel1 := context.WithCancel(context.Background())
ctx2, cancel2 := context.WithCancel(context.Background())
req1 = req1.WithContext(ctx1)
req2 = req2.WithContext(ctx2)
roundTripCalled := make(chan struct{})
reqFinished := make(chan struct{})
rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().Close()
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) {
roundTripCalled <- struct{}{}
<-r.Context().Done()
return nil, nil
})
return cl, nil
}
go func() {
rt.RoundTrip(req1)
reqFinished <- struct{}{}
}()
go func() {
rt.RoundTrip(req2)
reqFinished <- struct{}{}
}()
<-roundTripCalled
<-roundTripCalled
// Both two requests are started.
Expect(len(rt.clients)).To(Equal(2))
cancel1()
<-reqFinished
// req1 is finished
rt.CloseIdleConnections()
Expect(len(rt.clients)).To(Equal(1))
cancel2()
<-reqFinished
// all requests are finished
rt.CloseIdleConnections()
Expect(len(rt.clients)).To(Equal(0))
})
}) })
Context("reusing udpconn", func() { Context("reusing udpconn", func() {