don't remove closed connections from the server's accept queue (#4245)

This commit is contained in:
Marten Seemann 2024-01-18 22:45:38 -08:00 committed by GitHub
parent cb1775a08a
commit 594440b04c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 27 additions and 81 deletions

View file

@ -382,20 +382,34 @@ var _ = Describe("Handshake tests", func() {
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
// Now close the one of the connection that are waiting to be accepted.
// This should free one spot in the queue.
Expect(firstConn.CloseWithError(0, ""))
const appErrCode quic.ApplicationErrorCode = 12345
Expect(firstConn.CloseWithError(appErrCode, ""))
Eventually(firstConn.Context().Done()).Should(BeClosed())
time.Sleep(scaleDuration(200 * time.Millisecond))
// dial again, and expect that this dial succeeds
_, err = dial()
Expect(err).ToNot(HaveOccurred())
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
// dial again, and expect that this fails again
_, err = dial()
Expect(err).To(HaveOccurred())
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
// now accept all connections
var closedConn quic.Connection
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
if conn.Context().Err() != nil {
if closedConn != nil {
Fail("only expected a single closed connection")
}
closedConn = conn
}
}
Expect(closedConn).ToNot(BeNil()) // there should be exactly one closed connection
_, err = closedConn.AcceptStream(context.Background())
var appErr *quic.ApplicationError
Expect(errors.As(err, &appErr)).To(BeTrue())
Expect(appErr.ErrorCode).To(Equal(appErrCode))
})
It("closes handshaking connections when the server is closed", func() {

View file

@ -7,7 +7,6 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/handshake"
@ -111,8 +110,7 @@ type baseServer struct {
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket
connQueue chan quicConn
connQueueLen int32 // to be used as an atomic
connQueue chan quicConn
tracer *logging.Tracer
@ -251,7 +249,7 @@ func newServer(
maxTokenAge: maxTokenAge,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
@ -322,7 +320,6 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
case <-ctx.Done():
return nil, ctx.Err()
case conn := <-s.connQueue:
atomic.AddInt32(&s.connQueueLen, -1)
return conn, nil
case <-s.errorChan:
return nil, s.closeErr
@ -605,7 +602,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
if queueLen := len(s.connQueue); queueLen >= protocol.MaxAcceptQueueSize {
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
@ -713,13 +710,10 @@ func (s *baseServer) handleNewConn(conn quicConn) {
}
}
atomic.AddInt32(&s.connQueueLen, 1)
select {
case s.connQueue <- conn:
// blocks until the connection is accepted
case <-connCtx.Done():
atomic.AddInt32(&s.connQueueLen, -1)
// don't pass connections that were already closed to Accept()
default:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
}
}

View file

@ -699,68 +699,6 @@ var _ = Describe("Server", func() {
serv.handlePacket(p)
Eventually(done).Should(BeClosed())
})
It("doesn't accept new connections if they were closed in the mean time", func() {
p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
ctx, cancel := context.WithCancel(context.Background())
connCreated := make(chan struct{})
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
conn.EXPECT().handlePacket(p)
conn.EXPECT().run()
conn.EXPECT().Context().Return(ctx)
c := make(chan struct{})
close(c)
conn.EXPECT().HandshakeComplete().Return(c)
close(connCreated)
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
serv.handlePacket(p)
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
Eventually(connCreated).Should(BeClosed())
cancel()
time.Sleep(scaleDuration(200 * time.Millisecond))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.Accept(context.Background())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
})
Context("token validation", func() {
@ -1289,7 +1227,7 @@ var _ = Describe("Server", func() {
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
}
Eventually(func() int32 { return atomic.LoadInt32(&serv.baseServer.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize))
Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)