From 86a1234c870c1f91b7069112ff9ae59179aff027 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 25 Mar 2023 09:19:23 +1100 Subject: [PATCH] make EarlyListener a struct, not an interface --- http3/mock_quic_early_listener_test.go | 80 ++++++++++++++++++++++++++ http3/mockgen.go | 2 + http3/server.go | 36 ++++++++---- http3/server_test.go | 51 +++++++--------- integrationtests/self/hotswap_test.go | 6 +- integrationtests/self/zero_rtt_test.go | 4 +- interface.go | 11 ---- internal/mocks/mockgen.go | 1 - internal/mocks/quic/early_listener.go | 80 -------------------------- interop/http09/http_test.go | 4 +- interop/http09/server.go | 2 +- server.go | 33 ++++++++--- server_test.go | 28 ++++----- 13 files changed, 174 insertions(+), 164 deletions(-) create mode 100644 http3/mock_quic_early_listener_test.go delete mode 100644 internal/mocks/quic/early_listener.go diff --git a/http3/mock_quic_early_listener_test.go b/http3/mock_quic_early_listener_test.go new file mode 100644 index 00000000..ab40f060 --- /dev/null +++ b/http3/mock_quic_early_listener_test.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/http3 (interfaces: QUICEarlyListener) + +// Package http3 is a generated GoMock package. +package http3 + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + quic "github.com/quic-go/quic-go" +) + +// MockQUICEarlyListener is a mock of QUICEarlyListener interface. +type MockQUICEarlyListener struct { + ctrl *gomock.Controller + recorder *MockQUICEarlyListenerMockRecorder +} + +// MockQUICEarlyListenerMockRecorder is the mock recorder for MockQUICEarlyListener. +type MockQUICEarlyListenerMockRecorder struct { + mock *MockQUICEarlyListener +} + +// NewMockQUICEarlyListener creates a new mock instance. +func NewMockQUICEarlyListener(ctrl *gomock.Controller) *MockQUICEarlyListener { + mock := &MockQUICEarlyListener{ctrl: ctrl} + mock.recorder = &MockQUICEarlyListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockQUICEarlyListener) EXPECT() *MockQUICEarlyListenerMockRecorder { + return m.recorder +} + +// Accept mocks base method. +func (m *MockQUICEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnection, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept", arg0) + ret0, _ := ret[0].(quic.EarlyConnection) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept. +func (mr *MockQUICEarlyListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockQUICEarlyListener)(nil).Accept), arg0) +} + +// Addr mocks base method. +func (m *MockQUICEarlyListener) Addr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// Addr indicates an expected call of Addr. +func (mr *MockQUICEarlyListenerMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockQUICEarlyListener)(nil).Addr)) +} + +// Close mocks base method. +func (m *MockQUICEarlyListener) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockQUICEarlyListenerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockQUICEarlyListener)(nil).Close)) +} diff --git a/http3/mockgen.go b/http3/mockgen.go index cb370373..38939e60 100644 --- a/http3/mockgen.go +++ b/http3/mockgen.go @@ -4,3 +4,5 @@ package http3 //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser" type RoundTripCloser = roundTripCloser + +//go:generate sh -c "go run github.com/golang/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener" diff --git a/http3/server.go b/http3/server.go index e74247ab..09b62a2b 100644 --- a/http3/server.go +++ b/http3/server.go @@ -23,8 +23,12 @@ import ( // allows mocking of quic.Listen and quic.ListenAddr var ( - quicListen = quic.ListenEarly - quicListenAddr = quic.ListenAddrEarly + quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { + return quic.ListenEarly(conn, tlsConf, config) + } + quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { + return quic.ListenAddrEarly(addr, tlsConf, config) + } ) const ( @@ -44,6 +48,15 @@ const ( streamTypeQPACKDecoderStream = 3 ) +// A QUICEarlyListener listens for incoming QUIC connections. +type QUICEarlyListener interface { + Accept(context.Context) (quic.EarlyConnection, error) + Addr() net.Addr + io.Closer +} + +var _ QUICEarlyListener = &quic.EarlyListener{} + func versionToALPN(v protocol.VersionNumber) string { //nolint:exhaustive // These are all the versions we care about. switch v { @@ -193,7 +206,7 @@ type Server struct { UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) mutex sync.RWMutex - listeners map[*quic.EarlyListener]listenerInfo + listeners map[*QUICEarlyListener]listenerInfo closed bool @@ -249,7 +262,7 @@ func (s *Server) ServeQUICConn(conn quic.Connection) error { // Make sure you use http3.ConfigureTLSConfig to configure a tls.Config // and use it to construct a http3-friendly QUIC listener. // Closing the server does close the listener. -func (s *Server) ServeListener(ln quic.EarlyListener) error { +func (s *Server) ServeListener(ln QUICEarlyListener) error { if err := s.addListener(&ln); err != nil { return err } @@ -283,7 +296,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { quicConf.EnableDatagrams = true } - var ln quic.EarlyListener + var ln QUICEarlyListener var err error if conn == nil { addr := s.Addr @@ -305,7 +318,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { return err } -func (s *Server) serveListener(ln quic.EarlyListener) error { +func (s *Server) serveListener(ln QUICEarlyListener) error { for { conn, err := ln.Accept(context.Background()) if err != nil { @@ -391,7 +404,7 @@ func (s *Server) generateAltSvcHeader() { // We store a pointer to interface in the map set. This is safe because we only // call trackListener via Serve and can track+defer untrack the same pointer to // local variable there. We never need to compare a Listener from another caller. -func (s *Server) addListener(l *quic.EarlyListener) error { +func (s *Server) addListener(l *QUICEarlyListener) error { s.mutex.Lock() defer s.mutex.Unlock() @@ -402,25 +415,24 @@ func (s *Server) addListener(l *quic.EarlyListener) error { s.logger = utils.DefaultLogger.WithPrefix("server") } if s.listeners == nil { - s.listeners = make(map[*quic.EarlyListener]listenerInfo) + s.listeners = make(map[*QUICEarlyListener]listenerInfo) } if port, err := extractPort((*l).Addr().String()); err == nil { s.listeners[l] = listenerInfo{port} } else { - s.logger.Errorf( - "Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) + s.logger.Errorf("Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) s.listeners[l] = listenerInfo{} } s.generateAltSvcHeader() return nil } -func (s *Server) removeListener(l *quic.EarlyListener) { +func (s *Server) removeListener(l *QUICEarlyListener) { s.mutex.Lock() + defer s.mutex.Unlock() delete(s.listeners, l) s.generateAltSvcHeader() - s.mutex.Unlock() } func (s *Server) handleConn(conn quic.Connection) error { diff --git a/http3/server_test.go b/http3/server_test.go index bbf8c9cc..572446ea 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -28,34 +28,25 @@ import ( gmtypes "github.com/onsi/gomega/types" ) -type mockAddr struct { - addr string -} +type mockAddr struct{ addr string } -func (ma *mockAddr) Network() string { - return "udp" -} - -func (ma *mockAddr) String() string { - return ma.addr -} +func (ma *mockAddr) Network() string { return "udp" } +func (ma *mockAddr) String() string { return ma.addr } type mockAddrListener struct { - *mockquic.MockEarlyListener + *MockQUICEarlyListener addr *mockAddr } func (m *mockAddrListener) Addr() net.Addr { - _ = m.MockEarlyListener.Addr() + _ = m.MockQUICEarlyListener.Addr() return m.addr } func newMockAddrListener(addr string) *mockAddrListener { return &mockAddrListener{ - MockEarlyListener: mockquic.NewMockEarlyListener(mockCtrl), - addr: &mockAddr{ - addr: addr, - }, + MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl), + addr: &mockAddr{addr: addr}, } } @@ -791,20 +782,20 @@ var _ = Describe("Server", func() { s.QuicConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionDraft29}} }) - var ln1 quic.EarlyListener - var ln2 quic.EarlyListener + var ln1 QUICEarlyListener + var ln2 QUICEarlyListener expected := http.Header{ "Alt-Svc": {`h3-29=":443"; ma=2592000`}, } - addListener := func(addr string, ln *quic.EarlyListener) { + addListener := func(addr string, ln *QUICEarlyListener) { mln := newMockAddrListener(addr) mln.EXPECT().Addr() *ln = mln s.addListener(ln) } - removeListener := func(ln *quic.EarlyListener) { + removeListener := func(ln *QUICEarlyListener) { s.removeListener(ln) } @@ -951,7 +942,7 @@ var _ = Describe("Server", func() { It("sets the GetConfigForClient callback if no tls.Config is given", func() { var receivedConf *tls.Config - quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { + quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) { receivedConf = tlsConf return nil, errors.New("listen err") } @@ -1021,7 +1012,7 @@ var _ = Describe("Server", func() { It("serves a packet conn", func() { ln := newMockAddrListener(":443") conn := &net.UDPConn{} - quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { Expect(c).To(Equal(conn)) return ln, nil } @@ -1052,12 +1043,12 @@ var _ = Describe("Server", func() { It("serves two packet conns", func() { ln1 := newMockAddrListener(":443") ln2 := newMockAddrListener(":8443") - lns := make(chan quic.EarlyListener, 2) + lns := make(chan QUICEarlyListener, 2) lns <- ln1 lns <- ln2 conn1 := &net.UDPConn{} conn2 := &net.UDPConn{} - quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { return <-lns, nil } @@ -1111,7 +1102,7 @@ var _ = Describe("Server", func() { It("serves a listener", func() { var called int32 ln := newMockAddrListener(":443") - quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { atomic.StoreInt32(&called, 1) return ln, nil } @@ -1142,10 +1133,10 @@ var _ = Describe("Server", func() { var called int32 ln1 := newMockAddrListener(":443") ln2 := newMockAddrListener(":8443") - lns := make(chan quic.EarlyListener, 2) + lns := make(chan QUICEarlyListener, 2) lns <- ln1 lns <- ln2 - quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { atomic.StoreInt32(&called, 1) return <-lns, nil } @@ -1225,7 +1216,7 @@ var _ = Describe("Server", func() { It("uses the quic.Config to start the QUIC server", func() { conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} var receivedConf *quic.Config - quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { receivedConf = config return nil, errors.New("listen err") } @@ -1241,7 +1232,7 @@ var _ = Describe("Server", func() { It("errors when listening fails", func() { testErr := errors.New("listen error") - quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { return nil, testErr } fullpem, privkey := testdata.GetCertificatePaths() @@ -1251,7 +1242,7 @@ var _ = Describe("Server", func() { It("supports H3_DATAGRAM", func() { s.EnableDatagrams = true var receivedConf *quic.Config - quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { receivedConf = config return nil, errors.New("listen err") } diff --git a/integrationtests/self/hotswap_test.go b/integrationtests/self/hotswap_test.go index 3bb008b9..6cd73079 100644 --- a/integrationtests/self/hotswap_test.go +++ b/integrationtests/self/hotswap_test.go @@ -18,14 +18,14 @@ import ( ) type listenerWrapper struct { - quic.EarlyListener + http3.QUICEarlyListener listenerClosed bool count int32 } func (ln *listenerWrapper) Close() error { ln.listenerClosed = true - return ln.EarlyListener.Close() + return ln.QUICEarlyListener.Close() } func (ln *listenerWrapper) Faker() *fakeClosingListener { @@ -91,7 +91,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() { tlsConf := http3.ConfigureTLSConfig(getTLSConfig()) quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil)) - ln = &listenerWrapper{EarlyListener: quicln} + ln = &listenerWrapper{QUICEarlyListener: quicln} Expect(err).NotTo(HaveOccurred()) port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 70b1199e..1559c176 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -99,7 +99,7 @@ var _ = Describe("0-RTT", func() { } transfer0RTTData := func( - ln quic.EarlyListener, + ln *quic.EarlyListener, proxyPort int, clientTLSConf *tls.Config, clientConf *quic.Config, @@ -147,7 +147,7 @@ var _ = Describe("0-RTT", func() { } check0RTTRejected := func( - ln quic.EarlyListener, + ln *quic.EarlyListener, proxyPort int, clientConf *tls.Config, ) { diff --git a/interface.go b/interface.go index 757a71ea..43101448 100644 --- a/interface.go +++ b/interface.go @@ -345,14 +345,3 @@ type ConnectionState struct { SupportsDatagrams bool Version VersionNumber } - -// An EarlyListener listens for incoming QUIC connections, -// and returns them before the handshake completes. -type EarlyListener interface { - // Close the server. All active connections will be closed. - Close() error - // Addr returns the local network addr that the server is listening on. - Addr() net.Addr - // Accept returns new early connections. It should be called in a loop. - Accept(context.Context) (EarlyConnection, error) -} diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 4f084302..8717fce0 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -2,7 +2,6 @@ package mocks //go:generate sh -c "go run github.com/golang/mock/mockgen -package mockquic -destination quic/stream.go github.com/quic-go/quic-go Stream" //go:generate sh -c "go run github.com/golang/mock/mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/quic-go/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package mockquic -destination quic/early_listener.go github.com/quic-go/quic-go EarlyListener" //go:generate sh -c "go run github.com/golang/mock/mockgen -package mocklogging -destination logging/tracer.go github.com/quic-go/quic-go/logging Tracer" //go:generate sh -c "go run github.com/golang/mock/mockgen -package mocklogging -destination logging/connection_tracer.go github.com/quic-go/quic-go/logging ConnectionTracer" //go:generate sh -c "go run github.com/golang/mock/mockgen -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer" diff --git a/internal/mocks/quic/early_listener.go b/internal/mocks/quic/early_listener.go deleted file mode 100644 index b7cb008c..00000000 --- a/internal/mocks/quic/early_listener.go +++ /dev/null @@ -1,80 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: EarlyListener) - -// Package mockquic is a generated GoMock package. -package mockquic - -import ( - context "context" - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - quic "github.com/quic-go/quic-go" -) - -// MockEarlyListener is a mock of EarlyListener interface. -type MockEarlyListener struct { - ctrl *gomock.Controller - recorder *MockEarlyListenerMockRecorder -} - -// MockEarlyListenerMockRecorder is the mock recorder for MockEarlyListener. -type MockEarlyListenerMockRecorder struct { - mock *MockEarlyListener -} - -// NewMockEarlyListener creates a new mock instance. -func NewMockEarlyListener(ctrl *gomock.Controller) *MockEarlyListener { - mock := &MockEarlyListener{ctrl: ctrl} - mock.recorder = &MockEarlyListenerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockEarlyListener) EXPECT() *MockEarlyListenerMockRecorder { - return m.recorder -} - -// Accept mocks base method. -func (m *MockEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnection, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Accept", arg0) - ret0, _ := ret[0].(quic.EarlyConnection) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Accept indicates an expected call of Accept. -func (mr *MockEarlyListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockEarlyListener)(nil).Accept), arg0) -} - -// Addr mocks base method. -func (m *MockEarlyListener) Addr() net.Addr { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Addr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// Addr indicates an expected call of Addr. -func (mr *MockEarlyListenerMockRecorder) Addr() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockEarlyListener)(nil).Addr)) -} - -// Close mocks base method. -func (m *MockEarlyListener) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockEarlyListenerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEarlyListener)(nil).Close)) -} diff --git a/interop/http09/http_test.go b/interop/http09/http_test.go index 87d916ef..f2d48994 100644 --- a/interop/http09/http_test.go +++ b/interop/http09/http_test.go @@ -36,8 +36,8 @@ var _ = Describe("HTTP 0.9 integration tests", func() { defer close(done) _ = server.ListenAndServe() }() - var ln quic.EarlyListener - Eventually(func() quic.EarlyListener { + var ln *quic.EarlyListener + Eventually(func() *quic.EarlyListener { server.mutex.Lock() defer server.mutex.Unlock() ln = server.listener diff --git a/interop/http09/server.go b/interop/http09/server.go index cbe852f6..b7b510d8 100644 --- a/interop/http09/server.go +++ b/interop/http09/server.go @@ -40,7 +40,7 @@ type Server struct { QuicConfig *quic.Config mutex sync.Mutex - listener quic.EarlyListener + listener *quic.EarlyListener } // Close closes the server. diff --git a/server.go b/server.go index 0c8857aa..3543e546 100644 --- a/server.go +++ b/server.go @@ -137,12 +137,29 @@ func (l *Listener) Addr() net.Addr { return l.baseServer.Addr() } -type earlyServer struct{ *baseServer } +// An EarlyListener listens for incoming QUIC connections, and returns them before the handshake completes. +// For connections that don't use 0-RTT, this allows the server to send 0.5-RTT data. +// This data is encrypted with forward-secure keys, however, the client's identity has not yet been verified. +// For connection using 0-RTT, this allows the server to accept and respond to streams that the client opened in the +// 0-RTT data it sent. Note that at this point during the handshake, the live-ness of the +// client has not yet been confirmed, and the 0-RTT data could have been replayed by an attacker. +type EarlyListener struct { + baseServer *baseServer +} -var _ EarlyListener = &earlyServer{} +// Accept returns a new connections. It should be called in a loop. +func (l *EarlyListener) Accept(ctx context.Context) (EarlyConnection, error) { + return l.baseServer.accept(ctx) +} -func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { - return s.baseServer.accept(ctx) +// Close the server. All active connections will be closed. +func (l *EarlyListener) Close() error { + return l.baseServer.Close() +} + +// Addr returns the local network addr that the server is listening on. +func (l *EarlyListener) Addr() net.Addr { + return l.baseServer.Addr() } // ListenAddr creates a QUIC server listening on a given address. @@ -157,12 +174,12 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, er } // ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. -func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { +func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { s, err := listenAddr(addr, tlsConf, config, true) if err != nil { return nil, err } - return &earlyServer{s}, nil + return &EarlyListener{baseServer: s}, nil } func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { @@ -201,12 +218,12 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener } // ListenEarly works like Listen, but it returns connections before the handshake completes. -func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { +func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { s, err := listen(conn, tlsConf, config, true) if err != nil { return nil, err } - return &earlyServer{s}, nil + return &EarlyListener{baseServer: s}, nil } func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { diff --git a/server_test.go b/server_test.go index 8b843a13..e9f751e6 100644 --- a/server_test.go +++ b/server_test.go @@ -1051,16 +1051,16 @@ var _ = Describe("Server", func() { Context("server accepting connections that haven't completed the handshake", func() { var ( - serv *earlyServer + serv *EarlyListener phm *MockPacketHandlerManager ) BeforeEach(func() { - ln, err := ListenEarly(conn, tlsConf, nil) + var err error + serv, err = ListenEarly(conn, tlsConf, nil) Expect(err).ToNot(HaveOccurred()) - serv = ln.(*earlyServer) phm = NewMockPacketHandlerManager(mockCtrl) - serv.connHandler = phm + serv.baseServer.connHandler = phm }) AfterEach(func() { @@ -1081,7 +1081,7 @@ var _ = Describe("Server", func() { }() ready := make(chan struct{}) - serv.newConn = func( + serv.baseServer.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, @@ -1111,7 +1111,7 @@ var _ = Describe("Server", func() { fn() return true }) - serv.handleInitialImpl( + serv.baseServer.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) @@ -1123,7 +1123,7 @@ var _ = Describe("Server", func() { It("rejects new connection attempts if the accept queue is full", func() { senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - serv.newConn = func( + serv.baseServer.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, @@ -1158,10 +1158,10 @@ var _ = Describe("Server", func() { return true }).Times(protocol.MaxAcceptQueueSize) for i := 0; i < protocol.MaxAcceptQueueSize; i++ { - serv.handlePacket(getInitialWithRandomDestConnID()) + serv.baseServer.handlePacket(getInitialWithRandomDestConnID()) } - Eventually(func() int32 { return atomic.LoadInt32(&serv.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) + Eventually(func() int32 { return atomic.LoadInt32(&serv.baseServer.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) @@ -1177,7 +1177,7 @@ var _ = Describe("Server", func() { Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) return len(b), nil }) - serv.handlePacket(p) + serv.baseServer.handlePacket(p) Eventually(done).Should(BeClosed()) }) @@ -1186,7 +1186,7 @@ var _ = Describe("Server", func() { ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) conn := NewMockQUICConn(mockCtrl) - serv.newConn = func( + serv.baseServer.newConn = func( _ sendConn, runner connRunner, _ protocol.ConnectionID, @@ -1218,7 +1218,7 @@ var _ = Describe("Server", func() { fn() return true }) - serv.handlePacket(p) + serv.baseServer.handlePacket(p) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) Eventually(connCreated).Should(BeClosed()) @@ -1243,7 +1243,7 @@ var _ = Describe("Server", func() { Context("0-RTT", func() { var ( - serv *earlyServer + serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer ) @@ -1252,8 +1252,8 @@ var _ = Describe("Server", func() { tracer = mocklogging.NewMockTracer(mockCtrl) ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer}) Expect(err).ToNot(HaveOccurred()) - serv = ln.(*earlyServer) phm = NewMockPacketHandlerManager(mockCtrl) + serv = ln.baseServer serv.connHandler = phm })