diff --git a/Changelog.md b/Changelog.md index 64f47a39..f365d1f6 100644 --- a/Changelog.md +++ b/Changelog.md @@ -8,6 +8,7 @@ - Enforce application protocol negotiation (via `tls.Config.NextProtos`). - Use a varint for error codes. - Add support for [quic-trace](https://github.com/google/quic-trace). +- Add a context to `Listener.Accept`. ## v0.11.0 (2019-04-05) diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index 6685115e..846fb84c 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -2,6 +2,7 @@ package benchmark import ( "bytes" + "context" "crypto/tls" "fmt" "io" @@ -44,7 +45,7 @@ func init() { ) Expect(err).ToNot(HaveOccurred()) serverAddr <- ln.Addr() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) // wait for the client to complete the handshake before sending the data // this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets diff --git a/example/echo/echo.go b/example/echo/echo.go index 44a93103..77c6522f 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -35,7 +36,7 @@ func echoServer() error { if err != nil { return err } - sess, err := listener.Accept() + sess, err := listener.Accept(context.Background()) if err != nil { return err } diff --git a/http3/server.go b/http3/server.go index 917b0bae..689ec41f 100644 --- a/http3/server.go +++ b/http3/server.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -114,7 +115,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { s.listenerMutex.Unlock() for { - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) if err != nil { return err } diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index dffdea20..d2756d4a 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io" "io/ioutil" @@ -33,7 +34,7 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() var wg sync.WaitGroup wg.Add(numStreams) - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { @@ -196,7 +197,7 @@ var _ = Describe("Stream Cancelations", func() { var canceledCounter int32 go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { @@ -227,7 +228,7 @@ var _ = Describe("Stream Cancelations", func() { var canceledCounter int32 go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { @@ -265,7 +266,7 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() var wg sync.WaitGroup wg.Add(numStreams) - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { @@ -339,7 +340,7 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() var wg sync.WaitGroup wg.Add(numStreams) - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 8128686a..bc9098d1 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "math/rand" @@ -25,7 +26,7 @@ var _ = Describe("Connection ID lengths tests", func() { go func() { defer GinkgoRecover() for { - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) if err != nil { return } diff --git a/integrationtests/self/deadline_test.go b/integrationtests/self/deadline_test.go index 4b28e811..829d7ec0 100644 --- a/integrationtests/self/deadline_test.go +++ b/integrationtests/self/deadline_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -27,7 +28,7 @@ var _ = Describe("Stream deadline tests", func() { acceptedStream := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) serverStr, err = sess.AcceptStream() Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index bf3265ae..e8bdbba4 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "math/rand" "net" @@ -88,7 +89,7 @@ var _ = Describe("Drop Tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 5a6a8dff..28dd95e8 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" mrand "math/rand" "net" @@ -57,7 +58,7 @@ var _ = Describe("Handshake drop tests", func() { serverSessionChan := make(chan quic.Session) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) defer sess.Close() str, err := sess.AcceptStream() @@ -92,7 +93,7 @@ var _ = Describe("Handshake drop tests", func() { serverSessionChan := make(chan quic.Session) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -126,7 +127,7 @@ var _ = Describe("Handshake drop tests", func() { serverSessionChan := make(chan quic.Session) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) serverSessionChan <- sess }() diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 16a993a5..01053295 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "crypto/tls" "fmt" "net" @@ -56,8 +57,7 @@ var _ = Describe("Handshake RTT tests", func() { defer GinkgoRecover() defer close(acceptStopped) for { - _, err := server.Accept() - if err != nil { + if _, err := server.Accept(context.Background()); err != nil { return } } diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 260b2bff..17e2b90a 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "crypto/tls" "fmt" "net" @@ -50,7 +51,7 @@ var _ = Describe("Handshake tests", func() { defer GinkgoRecover() defer close(acceptStopped) for { - if _, err := server.Accept(); err != nil { + if _, err := server.Accept(context.Background()); err != nil { return } } @@ -236,7 +237,7 @@ var _ = Describe("Handshake tests", func() { Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ServerBusy)) // now accept one session, freeing one spot in the queue - _, err = server.Accept() + _, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) // dial again, and expect that this dial succeeds sess, err := dial() @@ -289,7 +290,7 @@ var _ = Describe("Handshake tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) cs := sess.ConnectionState() Expect(cs.NegotiatedProtocol).To(Equal(alpn)) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 3e818074..f6b9f0ab 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -25,7 +26,7 @@ var _ = Describe("Multiplexing", func() { go func() { defer GinkgoRecover() for { - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) if err != nil { return } diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index 89059fce..fad7fe09 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "crypto/tls" "fmt" "net" @@ -55,11 +56,11 @@ var _ = Describe("TLS session resumption", func() { go func() { defer close(done) defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(sess.ConnectionState().DidResume).To(BeFalse()) - sess, err = server.Accept() + sess, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(sess.ConnectionState().DidResume).To(BeTrue()) }() diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index 34a54e68..b3782433 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -42,7 +43,7 @@ var _ = Describe("non-zero RTT", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 22745cfe..85826d6b 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "math/rand" "net" @@ -33,7 +34,7 @@ var _ = Describe("Stateless Resets", func() { go func() { defer GinkgoRecover() - sess, err := ln.Accept() + sess, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -86,7 +87,7 @@ var _ = Describe("Stateless Resets", func() { acceptStopped := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := ln2.Accept() + _, err := ln2.Accept(context.Background()) Expect(err).To(HaveOccurred()) close(acceptStopped) }() diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index 72df4d2f..bc2fbb93 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -92,7 +93,7 @@ var _ = Describe("Bidirectional streams", func() { go func() { defer GinkgoRecover() var err error - sess, err = server.Accept() + sess, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runReceivingPeer(sess) }() @@ -109,7 +110,7 @@ var _ = Describe("Bidirectional streams", func() { It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runSendingPeer(sess) sess.Close() @@ -129,7 +130,7 @@ var _ = Describe("Bidirectional streams", func() { done1 := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 3c50a277..eaa23215 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -70,7 +70,7 @@ var _ = Describe("Timeout tests", func() { go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) @@ -146,7 +146,7 @@ var _ = Describe("Timeout tests", func() { serverSessionClosed := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) sess.AcceptStream() // blocks until the session is closed close(serverSessionClosed) @@ -187,7 +187,7 @@ var _ = Describe("Timeout tests", func() { serverSessionClosed := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) sess.AcceptStream() // blocks until the session is closed close(serverSessionClosed) diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index 1d5a044c..67bad1d2 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -1,6 +1,7 @@ package self_test import ( + "context" "fmt" "io/ioutil" "net" @@ -72,7 +73,7 @@ var _ = Describe("Unidirectional Streams", func() { It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runReceivingPeer(sess) sess.Close() @@ -91,7 +92,7 @@ var _ = Describe("Unidirectional Streams", func() { It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runSendingPeer(sess) }() @@ -109,7 +110,7 @@ var _ = Describe("Unidirectional Streams", func() { done1 := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept() + sess, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { diff --git a/interface.go b/interface.go index 5362e330..f8dc9471 100644 --- a/interface.go +++ b/interface.go @@ -232,5 +232,5 @@ type Listener interface { // Addr returns the local network addr that the server is listening on. Addr() net.Addr // Accept returns new sessions. It should be called in a loop. - Accept() (Session, error) + Accept(context.Context) (Session, error) } diff --git a/server.go b/server.go index c080d063..f6510bdc 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -284,9 +285,11 @@ func populateServerConfig(config *Config) *Config { } // Accept returns newly openend sessions -func (s *server) Accept() (Session, error) { +func (s *server) Accept(ctx context.Context) (Session, error) { var sess Session select { + case <-ctx.Done(): + return nil, ctx.Err() case sess = <-s.sessionQueue: return sess, nil case <-s.errorChan: diff --git a/server_test.go b/server_test.go index 008d7dcc..0a966688 100644 --- a/server_test.go +++ b/server_test.go @@ -434,7 +434,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - serv.Accept() + serv.Accept(context.Background()) close(done) }() Consistently(done).ShouldNot(BeClosed()) @@ -461,7 +461,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := serv.Accept() + _, err := serv.Accept(context.Background()) Expect(err).To(MatchError(testErr)) close(done) }() @@ -474,18 +474,33 @@ var _ = Describe("Server", func() { testErr := errors.New("test err") serv.setCloseError(testErr) for i := 0; i < 3; i++ { - _, err := serv.Accept() + _, err := serv.Accept(context.Background()) Expect(err).To(MatchError(testErr)) } }) + It("returns when the context is canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := serv.Accept(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + + Consistently(done).ShouldNot(BeClosed()) + cancel() + Eventually(done).Should(BeClosed()) + }) + It("accepts new sessions when the handshake completes", func() { sess := NewMockQuicSession(mockCtrl) done := make(chan struct{}) go func() { defer GinkgoRecover() - s, err := serv.Accept() + s, err := serv.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(s).To(Equal(sess)) close(done)