diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index f4de3574..63c32677 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -20,7 +20,7 @@ var _ = Describe("Stream Cancellations", func() { const numStreams = 80 Context("canceling the read side", func() { - var server quic.Listener + var server *quic.Listener // The server accepts a single connection, and then opens numStreams unidirectional streams. // On each of these streams, it (tries to) write PRData. @@ -222,7 +222,7 @@ var _ = Describe("Stream Cancellations", func() { }) Context("canceling the write side", func() { - runClient := func(server quic.Listener) int32 /* number of canceled streams */ { + runClient := func(server *quic.Listener) int32 /* number of canceled streams */ { conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index ddf3cfc5..c835a2ef 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -34,7 +34,7 @@ func (c *connIDGenerator) ConnectionIDLen() int { var _ = Describe("Connection ID lengths tests", func() { randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) } - runServer := func(conf *quic.Config) quic.Listener { + runServer := func(conf *quic.Config) *quic.Listener { GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength))) ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/deadline_test.go b/integrationtests/self/deadline_test.go index 2dcb4b3b..b165aff0 100644 --- a/integrationtests/self/deadline_test.go +++ b/integrationtests/self/deadline_test.go @@ -14,7 +14,7 @@ import ( ) var _ = Describe("Stream deadline tests", func() { - setup := func() (quic.Listener, quic.Stream, quic.Stream) { + setup := func() (*quic.Listener, quic.Stream, quic.Stream) { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) strChan := make(chan quic.SendStream) diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index f3265f1c..4eac657a 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -22,7 +22,7 @@ func randomDuration(min, max time.Duration) time.Duration { var _ = Describe("Drop Tests", func() { var ( proxy *quicproxy.QuicProxy - ln quic.Listener + ln *quic.Listener ) startListenerAndProxy := func(dropCallback quicproxy.DropCallback) { diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index e4d6d6b3..ae483771 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -32,7 +32,7 @@ type applicationProtocol struct { var _ = Describe("Handshake drop tests", func() { var ( proxy *quicproxy.QuicProxy - ln quic.Listener + ln *quic.Listener ) data := GeneratePRData(5000) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 7a5d000f..3274b84c 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -46,7 +46,7 @@ func (c *tokenStore) Pop(key string) *quic.ClientToken { var _ = Describe("Handshake tests", func() { var ( - server quic.Listener + server *quic.Listener serverConfig *quic.Config acceptStopped chan struct{} ) @@ -221,7 +221,7 @@ var _ = Describe("Handshake tests", func() { Context("rate limiting", func() { var ( - server quic.Listener + server *quic.Listener pconn net.PacketConn ) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 825ccb9f..b75d1656 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -14,7 +14,7 @@ import ( ) var _ = Describe("Multiplexing", func() { - runServer := func(ln quic.Listener) { + runServer := func(ln *quic.Listener) { go func() { defer GinkgoRecover() for { @@ -52,7 +52,7 @@ var _ = Describe("Multiplexing", func() { } Context("multiplexing clients on the same conn", func() { - getListener := func() quic.Listener { + getListener := func() *quic.Listener { ln, err := quic.ListenAddr( "localhost:0", getTLSConfig(), diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index 7482a9c9..97d9b981 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -15,7 +15,7 @@ import ( ) var _ = Describe("non-zero RTT", func() { - runServer := func() quic.Listener { + runServer := func() *quic.Listener { ln, err := quic.ListenAddr( "localhost:0", getTLSConfig(), diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index 59484deb..0af14b8f 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -16,7 +16,7 @@ var _ = Describe("Bidirectional streams", func() { const numStreams = 300 var ( - server quic.Listener + server *quic.Listener serverAddr string ) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 4fb2a733..5996a534 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -377,7 +377,7 @@ var _ = Describe("Timeout tests", func() { Context("faulty packet conns", func() { const handshakeTimeout = time.Second / 2 - runServer := func(ln quic.Listener) error { + runServer := func(ln *quic.Listener) error { conn, err := ln.Accept(context.Background()) if err != nil { return err diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index eb062de3..377fbe1e 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -61,7 +61,7 @@ var _ = Describe("Handshake tests", func() { quicClientConf := addTracers(protocol.PerspectiveClient, getQuicConfig(nil)) quicServerConf := addTracers(protocol.PerspectiveServer, getQuicConfig(nil)) - serverChan := make(chan quic.Listener) + serverChan := make(chan *quic.Listener) go func() { defer GinkgoRecover() ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), quicServerConf) diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index a809d9d3..9253b701 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -18,7 +18,7 @@ var _ = Describe("Unidirectional Streams", func() { const numStreams = 500 var ( - server quic.Listener + server *quic.Listener serverAddr string ) diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index a38edf30..ad3b7f7a 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -45,7 +45,7 @@ func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src lo } var _ = Describe("Handshake tests", func() { - startServer := func(tlsConf *tls.Config, conf *quic.Config) (quic.Listener, func()) { + startServer := func(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, func()) { server, err := quic.ListenAddr("localhost:0", tlsConf, conf) Expect(err).ToNot(HaveOccurred()) diff --git a/interface.go b/interface.go index b700e7c1..757a71ea 100644 --- a/interface.go +++ b/interface.go @@ -346,16 +346,6 @@ type ConnectionState struct { Version VersionNumber } -// A Listener for incoming QUIC connections -type Listener interface { - // Close the server. All active connections will be closed. - Close() error - // Addr returns the local network addr that the server is listening on. - Addr() net.Addr - // Accept returns new connections. It should be called in a loop. - Accept(context.Context) (Connection, error) -} - // An EarlyListener listens for incoming QUIC connections, // and returns them before the handshake completes. type EarlyListener interface { diff --git a/server.go b/server.go index d984beea..0c8857aa 100644 --- a/server.go +++ b/server.go @@ -114,10 +114,28 @@ type baseServer struct { logger utils.Logger } -var ( - _ Listener = &baseServer{} - _ unknownPacketHandler = &baseServer{} -) +var _ unknownPacketHandler = &baseServer{} + +// A Listener listens for incoming QUIC connections. +// It returns connections once the handshake has completed. +type Listener struct { + baseServer *baseServer +} + +// Accept returns new connections. It should be called in a loop. +func (l *Listener) Accept(ctx context.Context) (Connection, error) { + return l.baseServer.Accept(ctx) +} + +// Close the server. All active connections will be closed. +func (l *Listener) Close() error { + return l.baseServer.Close() +} + +// Addr returns the local network address that the server is listening on. +func (l *Listener) Addr() net.Addr { + return l.baseServer.Addr() +} type earlyServer struct{ *baseServer } @@ -130,8 +148,12 @@ func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { // ListenAddr creates a QUIC server listening on a given address. // The tls.Config must not be nil and must contain a certificate configuration. // The quic.Config may be nil, in that case the default values will be used. -func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { - return listenAddr(addr, tlsConf, config, false) +func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) { + s, err := listenAddr(addr, tlsConf, config, false) + if err != nil { + return nil, err + } + return &Listener{baseServer: s}, nil } // ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. @@ -170,8 +192,12 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo // The tls.Config must not be nil and must contain a certificate configuration. // Furthermore, it must define an application control (using NextProtos). // The quic.Config may be nil, in that case the default values will be used. -func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { - return listen(conn, tlsConf, config, false) +func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) { + s, err := listen(conn, tlsConf, config, false) + if err != nil { + return nil, err + } + return &Listener{baseServer: s}, nil } // ListenEarly works like Listen, but it returns connections before the handshake completes. diff --git a/server_test.go b/server_test.go index a349efc1..8b843a13 100644 --- a/server_test.go +++ b/server_test.go @@ -120,7 +120,7 @@ var _ = Describe("Server", func() { It("fills in default values if options are not set in the Config", func() { ln, err := Listen(conn, tlsConf, &Config{}) Expect(err).ToNot(HaveOccurred()) - server := ln.(*baseServer) + server := ln.baseServer Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) @@ -143,7 +143,7 @@ var _ = Describe("Server", func() { } ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) - server := ln.(*baseServer) + server := ln.baseServer Expect(server.connHandler).ToNot(BeNil()) Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) @@ -187,7 +187,7 @@ var _ = Describe("Server", func() { tracer = mocklogging.NewMockTracer(mockCtrl) ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer}) Expect(err).ToNot(HaveOccurred()) - serv = ln.(*baseServer) + serv = ln.baseServer phm = NewMockPacketHandlerManager(mockCtrl) serv.connHandler = phm })