// Mirror of https://github.com/refraction-networking/uquic.git
// (synced 2025-04-02 19:57:35 +03:00; 369 lines, 12 KiB, Go).
package http3
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
"github.com/quic-go/quic-go/internal/qerr"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
// mockBody is a scriptable request body for tests. It serves the bytes
// installed via SetData, can be made to fail Read or Close with preset
// errors, and records whether Close has been called.
type mockBody struct {
	reader   bytes.Reader // source of the bytes handed out by Read
	readErr  error        // when non-nil, returned by every Read call
	closeErr error        // returned by Close (may be nil)
	closed   bool         // set to true by Close
}

// make sure the mockBody can be used as a http.Request.Body
var _ io.ReadCloser = (*mockBody)(nil)

// Read returns (0, readErr) without touching p when readErr is set;
// otherwise it reads from the data installed via SetData.
func (m *mockBody) Read(p []byte) (int, error) {
	if m.readErr != nil {
		return 0, m.readErr
	}
	return m.reader.Read(p)
}

// SetData replaces the body content with data and rewinds the read position.
func (m *mockBody) SetData(data []byte) {
	// Reset re-initialises the embedded reader in place instead of copying a
	// freshly allocated bytes.Reader over it.
	m.reader.Reset(data)
}

// Close marks the body as closed and returns closeErr.
func (m *mockBody) Close() error {
	m.closed = true
	return m.closeErr
}
|
|
|
|
var _ = Describe("RoundTripper", func() {
|
|
var (
|
|
rt *RoundTripper
|
|
req *http.Request
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
rt = &RoundTripper{}
|
|
var err error
|
|
req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
Context("dialing hosts", func() {
|
|
It("creates new clients", func() {
|
|
testErr := errors.New("test err")
|
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
|
cl := NewMockRoundTripCloser(mockCtrl)
|
|
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
|
|
return cl, nil
|
|
}
|
|
_, err = rt.RoundTrip(req)
|
|
Expect(err).To(MatchError(testErr))
|
|
})
|
|
|
|
It("uses the quic.Config, if provided", func() {
|
|
config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
|
|
var receivedConfig *quic.Config
|
|
rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
|
|
receivedConfig = config
|
|
return nil, errors.New("handshake error")
|
|
}
|
|
rt.QuicConfig = config
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("handshake error"))
|
|
Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout))
|
|
})
|
|
|
|
It("uses the custom dialer, if provided", func() {
|
|
var dialed bool
|
|
dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
|
dialed = true
|
|
return nil, errors.New("handshake error")
|
|
}
|
|
rt.Dial = dialer
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("handshake error"))
|
|
Expect(dialed).To(BeTrue())
|
|
})
|
|
})
|
|
|
|
Context("reusing clients", func() {
|
|
var req1, req2 *http.Request
|
|
|
|
BeforeEach(func() {
|
|
var err error
|
|
req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(req1.URL).ToNot(Equal(req2.URL))
|
|
})
|
|
|
|
It("reuses existing clients", func() {
|
|
var count int
|
|
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
|
count++
|
|
cl := NewMockRoundTripCloser(mockCtrl)
|
|
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
|
return &http.Response{Request: req}, nil
|
|
}).Times(2)
|
|
cl.EXPECT().HandshakeComplete().Return(true)
|
|
return cl, nil
|
|
}
|
|
rsp1, err := rt.RoundTrip(req1)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(rsp1.Request.URL).To(Equal(req1.URL))
|
|
rsp2, err := rt.RoundTrip(req2)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(rsp2.Request.URL).To(Equal(req2.URL))
|
|
Expect(count).To(Equal(1))
|
|
})
|
|
|
|
It("immediately removes a clients when a request errored", func() {
|
|
testErr := errors.New("test err")
|
|
|
|
var count int
|
|
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
|
count++
|
|
cl := NewMockRoundTripCloser(mockCtrl)
|
|
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
|
|
return cl, nil
|
|
}
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError(testErr))
|
|
_, err = rt.RoundTrip(req2)
|
|
Expect(err).To(MatchError(testErr))
|
|
Expect(count).To(Equal(2))
|
|
})
|
|
|
|
It("recreates a client when a request times out", func() {
|
|
var reqCount int
|
|
cl1 := NewMockRoundTripCloser(mockCtrl)
|
|
cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
|
reqCount++
|
|
if reqCount == 1 { // the first request is successful...
|
|
Expect(req.URL).To(Equal(req1.URL))
|
|
return &http.Response{Request: req}, nil
|
|
}
|
|
// ... after that, the connection timed out in the background
|
|
Expect(req.URL).To(Equal(req2.URL))
|
|
return nil, &qerr.IdleTimeoutError{}
|
|
}).Times(2)
|
|
cl1.EXPECT().HandshakeComplete().Return(true)
|
|
cl2 := NewMockRoundTripCloser(mockCtrl)
|
|
cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
|
return &http.Response{Request: req}, nil
|
|
})
|
|
|
|
var count int
|
|
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
|
count++
|
|
if count == 1 {
|
|
return cl1, nil
|
|
}
|
|
return cl2, nil
|
|
}
|
|
rsp1, err := rt.RoundTrip(req1)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
|
|
rsp2, err := rt.RoundTrip(req2)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
|
|
})
|
|
|
|
It("only issues a request once, even if a timeout error occurs", func() {
|
|
var count int
|
|
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
|
count++
|
|
cl := NewMockRoundTripCloser(mockCtrl)
|
|
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
|
|
return cl, nil
|
|
}
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
|
Expect(count).To(Equal(1))
|
|
})
|
|
|
|
It("handles a burst of requests", func() {
|
|
wait := make(chan struct{})
|
|
reqs := make(chan struct{}, 2)
|
|
var count int
|
|
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
|
count++
|
|
cl := NewMockRoundTripCloser(mockCtrl)
|
|
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
|
reqs <- struct{}{}
|
|
<-wait
|
|
return nil, &qerr.IdleTimeoutError{}
|
|
}).Times(2)
|
|
cl.EXPECT().HandshakeComplete()
|
|
return cl, nil
|
|
}
|
|
done := make(chan struct{}, 2)
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
defer func() { done <- struct{}{} }()
|
|
_, err := rt.RoundTrip(req1)
|
|
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
|
}()
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
defer func() { done <- struct{}{} }()
|
|
_, err := rt.RoundTrip(req2)
|
|
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
|
}()
|
|
// wait for both requests to be issued
|
|
Eventually(reqs).Should(Receive())
|
|
Eventually(reqs).Should(Receive())
|
|
close(wait) // now return the requests
|
|
Eventually(done).Should(Receive())
|
|
Eventually(done).Should(Receive())
|
|
Expect(count).To(Equal(1))
|
|
})
|
|
|
|
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
|
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
|
|
Expect(err).To(MatchError(ErrNoCachedConn))
|
|
})
|
|
})
|
|
|
|
Context("validating request", func() {
|
|
It("rejects plain HTTP requests", func() {
|
|
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
|
req.Body = &mockBody{}
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
|
|
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
|
|
It("rejects requests without a URL", func() {
|
|
req.URL = nil
|
|
req.Body = &mockBody{}
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("http3: nil Request.URL"))
|
|
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
|
|
It("rejects request without a URL Host", func() {
|
|
req.URL.Host = ""
|
|
req.Body = &mockBody{}
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("http3: no Host in request URL"))
|
|
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
|
|
It("doesn't try to close the body if the request doesn't have one", func() {
|
|
req.URL = nil
|
|
Expect(req.Body).To(BeNil())
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("http3: nil Request.URL"))
|
|
})
|
|
|
|
It("rejects requests without a header", func() {
|
|
req.Header = nil
|
|
req.Body = &mockBody{}
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("http3: nil Request.Header"))
|
|
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
|
|
It("rejects requests with invalid header name fields", func() {
|
|
req.Header.Add("foobär", "value")
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
|
|
})
|
|
|
|
It("rejects requests with invalid header name values", func() {
|
|
req.Header.Add("foo", string([]byte{0x7}))
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
|
|
})
|
|
|
|
It("rejects requests with an invalid request method", func() {
|
|
req.Method = "foobär"
|
|
req.Body = &mockBody{}
|
|
_, err := rt.RoundTrip(req)
|
|
Expect(err).To(MatchError("http3: invalid method \"foobär\""))
|
|
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
|
})
|
|
})
|
|
|
|
Context("closing", func() {
|
|
It("closes", func() {
|
|
rt.clients = make(map[string]*roundTripCloserWithCount)
|
|
cl := NewMockRoundTripCloser(mockCtrl)
|
|
cl.EXPECT().Close()
|
|
rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}}
|
|
err := rt.Close()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(rt.clients)).To(BeZero())
|
|
})
|
|
|
|
It("closes a RoundTripper that has never been used", func() {
|
|
Expect(len(rt.clients)).To(BeZero())
|
|
err := rt.Close()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(len(rt.clients)).To(BeZero())
|
|
})
|
|
|
|
It("closes idle connections", func() {
|
|
Expect(len(rt.clients)).To(Equal(0))
|
|
req1, err := http.NewRequest("GET", "https://site1.com", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
req2, err := http.NewRequest("GET", "https://site2.com", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(req1.Host).ToNot(Equal(req2.Host))
|
|
ctx1, cancel1 := context.WithCancel(context.Background())
|
|
ctx2, cancel2 := context.WithCancel(context.Background())
|
|
req1 = req1.WithContext(ctx1)
|
|
req2 = req2.WithContext(ctx2)
|
|
roundTripCalled := make(chan struct{})
|
|
reqFinished := make(chan struct{})
|
|
rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
|
cl := NewMockRoundTripCloser(mockCtrl)
|
|
cl.EXPECT().Close()
|
|
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
|
roundTripCalled <- struct{}{}
|
|
<-r.Context().Done()
|
|
return nil, nil
|
|
})
|
|
return cl, nil
|
|
}
|
|
go func() {
|
|
rt.RoundTrip(req1)
|
|
reqFinished <- struct{}{}
|
|
}()
|
|
go func() {
|
|
rt.RoundTrip(req2)
|
|
reqFinished <- struct{}{}
|
|
}()
|
|
<-roundTripCalled
|
|
<-roundTripCalled
|
|
// Both two requests are started.
|
|
Expect(len(rt.clients)).To(Equal(2))
|
|
cancel1()
|
|
<-reqFinished
|
|
// req1 is finished
|
|
rt.CloseIdleConnections()
|
|
Expect(len(rt.clients)).To(Equal(1))
|
|
cancel2()
|
|
<-reqFinished
|
|
// all requests are finished
|
|
rt.CloseIdleConnections()
|
|
Expect(len(rt.clients)).To(Equal(0))
|
|
})
|
|
})
|
|
})
|