diff --git a/closed_conn.go b/closed_conn.go index 49fb83b6..fbee319c 100644 --- a/closed_conn.go +++ b/closed_conn.go @@ -41,8 +41,9 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) { c.sendPacket(p.remoteAddr, p.info) } -func (c *closedLocalConn) destroy(error) {} -func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective } +func (c *closedLocalConn) destroy(error) {} +func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {} +func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective } // A closedRemoteConn is a connection that was closed remotely. // For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. @@ -57,6 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler { return &closedRemoteConn{perspective: pers} } -func (s *closedRemoteConn) handlePacket(receivedPacket) {} -func (s *closedRemoteConn) destroy(error) {} -func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } +func (c *closedRemoteConn) handlePacket(receivedPacket) {} +func (c *closedRemoteConn) destroy(error) {} +func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {} +func (c *closedRemoteConn) getPerspective() protocol.Perspective { return c.perspective } diff --git a/connection.go b/connection.go index e32547c1..895c2524 100644 --- a/connection.go +++ b/connection.go @@ -1581,6 +1581,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro return nil } +func (s *connection) closeWithTransportError(code TransportErrorCode) { + s.closeLocal(&qerr.TransportError{ErrorCode: code}) + <-s.ctx.Done() +} + func (s *connection) handleCloseError(closeErr *closeError) { e := closeErr.err if e == nil { diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index f8996617..49ecf121 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -397,6 +397,41 @@ var _ = Describe("Handshake tests", func() { Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) }) + + It("closes handshaking connections when the server is closed", func() { + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + tr := quic.Transport{ + Conn: udpConn, + } + defer tr.Close() + tlsConf := &tls.Config{} + done := make(chan struct{}) + tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + <-done + return nil, errors.New("closed") + } + ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + + errChan := make(chan error, 1) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + go func() { + defer GinkgoRecover() + _, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil)) + errChan <- err + }() + time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued + Expect(ln.Close()).To(Succeed()) + close(done) + err = <-errChan + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) + }) }) Context("ALPN", func() { diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index 742dd19c..a031c1f5 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -12,6 +12,7 @@ import ( reflect "reflect" protocol "github.com/quic-go/quic-go/internal/protocol" + qerr "github.com/quic-go/quic-go/internal/qerr" gomock "go.uber.org/mock/gomock" ) @@ -38,6 +39,42 @@ func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder { return m.recorder } +// closeWithTransportError mocks base method. +func (m *MockPacketHandler) closeWithTransportError(arg0 qerr.TransportErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "closeWithTransportError", arg0) +} + +// closeWithTransportError indicates an expected call of closeWithTransportError. +func (mr *MockPacketHandlerMockRecorder) closeWithTransportError(arg0 any) *PacketHandlercloseWithTransportErrorCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithTransportError", reflect.TypeOf((*MockPacketHandler)(nil).closeWithTransportError), arg0) + return &PacketHandlercloseWithTransportErrorCall{Call: call} +} + +// PacketHandlercloseWithTransportErrorCall wrap *gomock.Call +type PacketHandlercloseWithTransportErrorCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *PacketHandlercloseWithTransportErrorCall) Return() *PacketHandlercloseWithTransportErrorCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *PacketHandlercloseWithTransportErrorCall) Do(f func(qerr.TransportErrorCode)) *PacketHandlercloseWithTransportErrorCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *PacketHandlercloseWithTransportErrorCall) DoAndReturn(f func(qerr.TransportErrorCode)) *PacketHandlercloseWithTransportErrorCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // destroy mocks base method. func (m *MockPacketHandler) destroy(arg0 error) { m.ctrl.T.Helper() diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index f9ab60bb..ce0d6192 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -656,6 +656,42 @@ func (c *QUICConnSendDatagramCall) DoAndReturn(f func([]byte) error) *QUICConnSe return c } +// closeWithTransportError mocks base method. +func (m *MockQUICConn) closeWithTransportError(arg0 qerr.TransportErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "closeWithTransportError", arg0) +} + +// closeWithTransportError indicates an expected call of closeWithTransportError. +func (mr *MockQUICConnMockRecorder) closeWithTransportError(arg0 any) *QUICConncloseWithTransportErrorCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithTransportError", reflect.TypeOf((*MockQUICConn)(nil).closeWithTransportError), arg0) + return &QUICConncloseWithTransportErrorCall{Call: call} +} + +// QUICConncloseWithTransportErrorCall wrap *gomock.Call +type QUICConncloseWithTransportErrorCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *QUICConncloseWithTransportErrorCall) Return() *QUICConncloseWithTransportErrorCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *QUICConncloseWithTransportErrorCall) Do(f func(qerr.TransportErrorCode)) *QUICConncloseWithTransportErrorCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *QUICConncloseWithTransportErrorCall) DoAndReturn(f func(qerr.TransportErrorCode)) *QUICConncloseWithTransportErrorCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // destroy mocks base method. func (m *MockQUICConn) destroy(arg0 error) { m.ctrl.T.Helper() diff --git a/server.go b/server.go index 8c05600d..1ef9a25c 100644 --- a/server.go +++ b/server.go @@ -25,6 +25,7 @@ var ErrServerClosed = errors.New("quic: server closed") type packetHandler interface { handlePacket(receivedPacket) destroy(error) + closeWithTransportError(qerr.TransportErrorCode) getPerspective() protocol.Perspective } @@ -44,6 +45,7 @@ type quicConn interface { getPerspective() protocol.Perspective run() error destroy(error) + closeWithTransportError(TransportErrorCode) } type zeroRTTQueue struct { @@ -693,7 +695,7 @@ func (s *baseServer) handleNewConn(conn quicConn) { // wait until the early connection is ready, the handshake fails, or the server is closed select { case <-s.errorChan: - conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}) + conn.closeWithTransportError(ConnectionRefused) return case <-conn.earlyConnReady(): case <-connCtx.Done(): @@ -703,7 +705,7 @@ func (s *baseServer) handleNewConn(conn quicConn) { // wait until the handshake is complete (or fails) select { case <-s.errorChan: - conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}) + conn.closeWithTransportError(ConnectionRefused) return case <-conn.HandshakeComplete(): case <-connCtx.Done(): diff --git a/server_test.go b/server_test.go index 2e6ee018..c3bdc5a6 100644 --- a/server_test.go +++ b/server_test.go @@ -327,7 +327,7 @@ var _ = Describe("Server", func() { Eventually(run).Should(BeClosed()) Eventually(done).Should(BeClosed()) // shutdown - conn.EXPECT().destroy(gomock.Any()) + conn.EXPECT().closeWithTransportError(gomock.Any()) }) It("sends a Version Negotiation Packet for unsupported versions", func() { @@ -530,7 +530,7 @@ var _ = Describe("Server", func() { Eventually(run).Should(BeClosed()) Eventually(done).Should(BeClosed()) // shutdown - conn.EXPECT().destroy(gomock.Any()).MaxTimes(1) + conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) }) It("drops packets if the receive queue is full", func() { @@ -570,7 +570,7 @@ var _ = Describe("Server", func() { conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1) // shutdown - conn.EXPECT().destroy(gomock.Any()).MaxTimes(1) + conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) return conn } @@ -1008,7 +1008,7 @@ var _ = Describe("Server", func() { ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().handlePacket(gomock.Any()) - conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}).Do(func(error) { close(destroyed) }) + conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) }) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().run().MaxTimes(1) conn.EXPECT().Context().Return(context.Background()) @@ -1468,7 +1468,7 @@ var _ = Describe("Server", func() { conn.EXPECT().Context().Return(context.Background()) close(called) // shutdown - conn.EXPECT().destroy(gomock.Any()) + conn.EXPECT().closeWithTransportError(gomock.Any()) return conn }