add a context to Listener.Accept

This commit is contained in:
Marten Seemann 2019-05-28 16:03:38 +02:00
parent 8dbe1684be
commit 12bce1caaa
21 changed files with 74 additions and 40 deletions

View file

@ -8,6 +8,7 @@
- Enforce application protocol negotiation (via `tls.Config.NextProtos`). - Enforce application protocol negotiation (via `tls.Config.NextProtos`).
- Use a varint for error codes. - Use a varint for error codes.
- Add support for [quic-trace](https://github.com/google/quic-trace). - Add support for [quic-trace](https://github.com/google/quic-trace).
- Add a context to `Listener.Accept`.
## v0.11.0 (2019-04-05) ## v0.11.0 (2019-04-05)

View file

@ -2,6 +2,7 @@ package benchmark
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
@ -44,7 +45,7 @@ func init() {
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverAddr <- ln.Addr() serverAddr <- ln.Addr()
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// wait for the client to complete the handshake before sending the data // wait for the client to complete the handshake before sending the data
// this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets // this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
@ -35,7 +36,7 @@ func echoServer() error {
if err != nil { if err != nil {
return err return err
} }
sess, err := listener.Accept() sess, err := listener.Accept(context.Background())
if err != nil { if err != nil {
return err return err
} }

View file

@ -2,6 +2,7 @@ package http3
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -114,7 +115,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
s.listenerMutex.Unlock() s.listenerMutex.Unlock()
for { for {
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -33,7 +34,7 @@ var _ = Describe("Stream Cancelations", func() {
defer GinkgoRecover() defer GinkgoRecover()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {
@ -196,7 +197,7 @@ var _ = Describe("Stream Cancelations", func() {
var canceledCounter int32 var canceledCounter int32
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {
@ -227,7 +228,7 @@ var _ = Describe("Stream Cancelations", func() {
var canceledCounter int32 var canceledCounter int32
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {
@ -265,7 +266,7 @@ var _ = Describe("Stream Cancelations", func() {
defer GinkgoRecover() defer GinkgoRecover()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {
@ -339,7 +340,7 @@ var _ = Describe("Stream Cancelations", func() {
defer GinkgoRecover() defer GinkgoRecover()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numStreams) wg.Add(numStreams)
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
go func() { go func() {

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
@ -25,7 +26,7 @@ var _ = Describe("Connection ID lengths tests", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
for { for {
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
if err != nil { if err != nil {
return return
} }

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -27,7 +28,7 @@ var _ = Describe("Stream deadline tests", func() {
acceptedStream := make(chan struct{}) acceptedStream := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverStr, err = sess.AcceptStream() serverStr, err = sess.AcceptStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
@ -88,7 +89,7 @@ var _ = Describe("Drop Tests", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream() str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
mrand "math/rand" mrand "math/rand"
"net" "net"
@ -57,7 +58,7 @@ var _ = Describe("Handshake drop tests", func() {
serverSessionChan := make(chan quic.Session) serverSessionChan := make(chan quic.Session)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer sess.Close() defer sess.Close()
str, err := sess.AcceptStream() str, err := sess.AcceptStream()
@ -92,7 +93,7 @@ var _ = Describe("Handshake drop tests", func() {
serverSessionChan := make(chan quic.Session) serverSessionChan := make(chan quic.Session)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream() str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -126,7 +127,7 @@ var _ = Describe("Handshake drop tests", func() {
serverSessionChan := make(chan quic.Session) serverSessionChan := make(chan quic.Session)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverSessionChan <- sess serverSessionChan <- sess
}() }()

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
@ -56,8 +57,7 @@ var _ = Describe("Handshake RTT tests", func() {
defer GinkgoRecover() defer GinkgoRecover()
defer close(acceptStopped) defer close(acceptStopped)
for { for {
_, err := server.Accept() if _, err := server.Accept(context.Background()); err != nil {
if err != nil {
return return
} }
} }

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
@ -50,7 +51,7 @@ var _ = Describe("Handshake tests", func() {
defer GinkgoRecover() defer GinkgoRecover()
defer close(acceptStopped) defer close(acceptStopped)
for { for {
if _, err := server.Accept(); err != nil { if _, err := server.Accept(context.Background()); err != nil {
return return
} }
} }
@ -236,7 +237,7 @@ var _ = Describe("Handshake tests", func() {
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ServerBusy)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ServerBusy))
// now accept one session, freeing one spot in the queue // now accept one session, freeing one spot in the queue
_, err = server.Accept() _, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// dial again, and expect that this dial succeeds // dial again, and expect that this dial succeeds
sess, err := dial() sess, err := dial()
@ -289,7 +290,7 @@ var _ = Describe("Handshake tests", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
cs := sess.ConnectionState() cs := sess.ConnectionState()
Expect(cs.NegotiatedProtocol).To(Equal(alpn)) Expect(cs.NegotiatedProtocol).To(Equal(alpn))

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -25,7 +26,7 @@ var _ = Describe("Multiplexing", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
for { for {
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
if err != nil { if err != nil {
return return
} }

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
@ -55,11 +56,11 @@ var _ = Describe("TLS session resumption", func() {
go func() { go func() {
defer close(done) defer close(done)
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(sess.ConnectionState().DidResume).To(BeFalse()) Expect(sess.ConnectionState().DidResume).To(BeFalse())
sess, err = server.Accept() sess, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(sess.ConnectionState().DidResume).To(BeTrue()) Expect(sess.ConnectionState().DidResume).To(BeTrue())
}() }()

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -42,7 +43,7 @@ var _ = Describe("non-zero RTT", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream() str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
@ -33,7 +34,7 @@ var _ = Describe("Stateless Resets", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := ln.Accept() sess, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream() str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -86,7 +87,7 @@ var _ = Describe("Stateless Resets", func() {
acceptStopped := make(chan struct{}) acceptStopped := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := ln2.Accept() _, err := ln2.Accept(context.Background())
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
close(acceptStopped) close(acceptStopped)
}() }()

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -92,7 +93,7 @@ var _ = Describe("Bidirectional streams", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
var err error var err error
sess, err = server.Accept() sess, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess) runReceivingPeer(sess)
}() }()
@ -109,7 +110,7 @@ var _ = Describe("Bidirectional streams", func() {
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess) runSendingPeer(sess)
sess.Close() sess.Close()
@ -129,7 +130,7 @@ var _ = Describe("Bidirectional streams", func() {
done1 := make(chan struct{}) done1 := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {

View file

@ -70,7 +70,7 @@ var _ = Describe("Timeout tests", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream() str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -146,7 +146,7 @@ var _ = Describe("Timeout tests", func() {
serverSessionClosed := make(chan struct{}) serverSessionClosed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sess.AcceptStream() // blocks until the session is closed sess.AcceptStream() // blocks until the session is closed
close(serverSessionClosed) close(serverSessionClosed)
@ -187,7 +187,7 @@ var _ = Describe("Timeout tests", func() {
serverSessionClosed := make(chan struct{}) serverSessionClosed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sess.AcceptStream() // blocks until the session is closed sess.AcceptStream() // blocks until the session is closed
close(serverSessionClosed) close(serverSessionClosed)

View file

@ -1,6 +1,7 @@
package self_test package self_test
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -72,7 +73,7 @@ var _ = Describe("Unidirectional Streams", func() {
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() { It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess) runReceivingPeer(sess)
sess.Close() sess.Close()
@ -91,7 +92,7 @@ var _ = Describe("Unidirectional Streams", func() {
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess) runSendingPeer(sess)
}() }()
@ -109,7 +110,7 @@ var _ = Describe("Unidirectional Streams", func() {
done1 := make(chan struct{}) done1 := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
sess, err := server.Accept() sess, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {

View file

@ -232,5 +232,5 @@ type Listener interface {
// Addr returns the local network addr that the server is listening on. // Addr returns the local network addr that the server is listening on.
Addr() net.Addr Addr() net.Addr
// Accept returns new sessions. It should be called in a loop. // Accept returns new sessions. It should be called in a loop.
Accept() (Session, error) Accept(context.Context) (Session, error)
} }

View file

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -284,9 +285,11 @@ func populateServerConfig(config *Config) *Config {
} }
// Accept returns newly openend sessions // Accept returns newly openend sessions
func (s *server) Accept() (Session, error) { func (s *server) Accept(ctx context.Context) (Session, error) {
var sess Session var sess Session
select { select {
case <-ctx.Done():
return nil, ctx.Err()
case sess = <-s.sessionQueue: case sess = <-s.sessionQueue:
return sess, nil return sess, nil
case <-s.errorChan: case <-s.errorChan:

View file

@ -434,7 +434,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
serv.Accept() serv.Accept(context.Background())
close(done) close(done)
}() }()
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
@ -461,7 +461,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := serv.Accept() _, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done) close(done)
}() }()
@ -474,18 +474,33 @@ var _ = Describe("Server", func() {
testErr := errors.New("test err") testErr := errors.New("test err")
serv.setCloseError(testErr) serv.setCloseError(testErr)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, err := serv.Accept() _, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
} }
}) })
It("returns when the context is canceled", func() {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(ctx)
Expect(err).To(MatchError("context canceled"))
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
cancel()
Eventually(done).Should(BeClosed())
})
It("accepts new sessions when the handshake completes", func() { It("accepts new sessions when the handshake completes", func() {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
s, err := serv.Accept() s, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).To(Equal(sess)) Expect(s).To(Equal(sess))
close(done) close(done)