diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 53885e1e..eef93c28 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -10,6 +10,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "golang.org/x/net/http/httpguts" @@ -25,6 +26,11 @@ type roundTripCloser interface { io.Closer } +type roundTripCloserWithCount struct { + roundTripCloser + useCount atomic.Int64 +} + // RoundTripper implements the http.RoundTripper interface type RoundTripper struct { mutex sync.Mutex @@ -82,7 +88,7 @@ type RoundTripper struct { MaxResponseHeaderBytes int64 newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests - clients map[string]roundTripCloser + clients map[string]*roundTripCloserWithCount udpConn *net.UDPConn } @@ -143,6 +149,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. if err != nil { return nil, err } + defer cl.useCount.Add(-1) rsp, err := cl.RoundTripOpt(req, opt) if err != nil { r.removeClient(hostname) @@ -160,12 +167,12 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{}) } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) { +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { r.mutex.Lock() defer r.mutex.Unlock() if r.clients == nil { - r.clients = make(map[string]roundTripCloser) + r.clients = make(map[string]*roundTripCloserWithCount) } client, ok := r.clients[hostname] @@ -188,7 +195,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri } dial = r.makeDialer() } - client, err = newCl( + c, err := newCl( hostname, r.TLSClientConfig, &roundTripperOpts{ @@ -204,10 +211,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri if err != nil { return nil, false, err } + client = &roundTripCloserWithCount{roundTripCloser: c} r.clients[hostname] = client } else if client.HandshakeComplete() { isReused = true } + client.useCount.Add(1) return client, isReused, nil } @@ -276,3 +285,14 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg) } } + +func (r *RoundTripper) CloseIdleConnections() { + r.mutex.Lock() + defer r.mutex.Unlock() + for hostname, client := range r.clients { + if client.useCount.Load() == 0 { + client.Close() + delete(r.clients, hostname) + } + } +} diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 14388b5f..8e859de5 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -8,6 +8,7 @@ import ( "io" "net" "net/http" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -304,10 +305,10 @@ var _ = Describe("RoundTripper", func() { Context("closing", func() { It("closes", func() { - rt.clients = make(map[string]roundTripCloser) + rt.clients = make(map[string]*roundTripCloserWithCount) cl := NewMockRoundTripCloser(mockCtrl) cl.EXPECT().Close() - rt.clients["foo.bar"] = cl + rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}} err := rt.Close() Expect(err).ToNot(HaveOccurred()) Expect(len(rt.clients)).To(BeZero()) @@ -319,6 +320,53 @@ var _ = Describe("RoundTripper", func() { Expect(err).ToNot(HaveOccurred()) Expect(len(rt.clients)).To(BeZero()) }) + + It("closes idle connections", func() { + Expect(len(rt.clients)).To(Equal(0)) + req1, err := http.NewRequest("GET", "https://site1.com", nil) + Expect(err).ToNot(HaveOccurred()) + req2, err := http.NewRequest("GET", "https://site2.com", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(req1.Host).ToNot(Equal(req2.Host)) + ctx1, cancel1 := context.WithCancel(context.Background()) + ctx2, cancel2 := context.WithCancel(context.Background()) + req1 = req1.WithContext(ctx1) + req2 = req2.WithContext(ctx2) + roundTripCalled := make(chan struct{}) + reqFinished := make(chan struct{}) + rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().Close() + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) { + roundTripCalled <- struct{}{} + <-r.Context().Done() + return nil, nil + }) + return cl, nil + } + go func() { + rt.RoundTrip(req1) + reqFinished <- struct{}{} + }() + go func() { + rt.RoundTrip(req2) + reqFinished <- struct{}{} + }() + <-roundTripCalled + <-roundTripCalled + // Both two requests are started. + Expect(len(rt.clients)).To(Equal(2)) + cancel1() + <-reqFinished + // req1 is finished + rt.CloseIdleConnections() + Expect(len(rt.clients)).To(Equal(1)) + cancel2() + <-reqFinished + // all requests are finished + rt.CloseIdleConnections() + Expect(len(rt.clients)).To(Equal(0)) + }) }) Context("reusing udpconn", func() {