mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
send out the CONNECTION_REFUSED error when refusing a connection (#4250)
So far, we used Connection.destroy, which destroys a connection without sending out a CONNECTION_CLOSE frame. This is useful (for example) when receiving a stateless reset, but it's not what we want when the server refuses an incoming connection. In this case, we want to send out a packet with a CONNECTION_CLOSE frame to inform the client that the connection attempt is being rejected.
This commit is contained in:
parent
b3eb375bc1
commit
cb1775a08a
7 changed files with 129 additions and 12 deletions
|
@ -41,8 +41,9 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) {
|
|||
c.sendPacket(p.remoteAddr, p.info)
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) destroy(error) {}
|
||||
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
|
||||
func (c *closedLocalConn) destroy(error) {}
|
||||
func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
|
||||
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
|
||||
|
||||
// A closedRemoteConn is a connection that was closed remotely.
|
||||
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
||||
|
@ -57,6 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
|
|||
return &closedRemoteConn{perspective: pers}
|
||||
}
|
||||
|
||||
func (s *closedRemoteConn) handlePacket(receivedPacket) {}
|
||||
func (s *closedRemoteConn) destroy(error) {}
|
||||
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
|
||||
func (c *closedRemoteConn) handlePacket(receivedPacket) {}
|
||||
func (c *closedRemoteConn) destroy(error) {}
|
||||
func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
|
||||
func (c *closedRemoteConn) getPerspective() protocol.Perspective { return c.perspective }
|
||||
|
|
|
@ -1581,6 +1581,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *connection) closeWithTransportError(code TransportErrorCode) {
|
||||
s.closeLocal(&qerr.TransportError{ErrorCode: code})
|
||||
<-s.ctx.Done()
|
||||
}
|
||||
|
||||
func (s *connection) handleCloseError(closeErr *closeError) {
|
||||
e := closeErr.err
|
||||
if e == nil {
|
||||
|
|
|
@ -397,6 +397,41 @@ var _ = Describe("Handshake tests", func() {
|
|||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
||||
})
|
||||
|
||||
It("closes handshaking connections when the server is closed", func() {
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := quic.Transport{
|
||||
Conn: udpConn,
|
||||
}
|
||||
defer tr.Close()
|
||||
tlsConf := &tls.Config{}
|
||||
done := make(chan struct{})
|
||||
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
<-done
|
||||
return nil, errors.New("closed")
|
||||
}
|
||||
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
|
||||
errChan <- err
|
||||
}()
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
close(done)
|
||||
err = <-errChan
|
||||
var transportErr *quic.TransportError
|
||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ALPN", func() {
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
reflect "reflect"
|
||||
|
||||
protocol "github.com/quic-go/quic-go/internal/protocol"
|
||||
qerr "github.com/quic-go/quic-go/internal/qerr"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
|
@ -38,6 +39,42 @@ func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder {
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
// closeWithTransportError mocks base method.
|
||||
func (m *MockPacketHandler) closeWithTransportError(arg0 qerr.TransportErrorCode) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "closeWithTransportError", arg0)
|
||||
}
|
||||
|
||||
// closeWithTransportError indicates an expected call of closeWithTransportError.
|
||||
func (mr *MockPacketHandlerMockRecorder) closeWithTransportError(arg0 any) *PacketHandlercloseWithTransportErrorCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithTransportError", reflect.TypeOf((*MockPacketHandler)(nil).closeWithTransportError), arg0)
|
||||
return &PacketHandlercloseWithTransportErrorCall{Call: call}
|
||||
}
|
||||
|
||||
// PacketHandlercloseWithTransportErrorCall wrap *gomock.Call
|
||||
type PacketHandlercloseWithTransportErrorCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *PacketHandlercloseWithTransportErrorCall) Return() *PacketHandlercloseWithTransportErrorCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *PacketHandlercloseWithTransportErrorCall) Do(f func(qerr.TransportErrorCode)) *PacketHandlercloseWithTransportErrorCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *PacketHandlercloseWithTransportErrorCall) DoAndReturn(f func(qerr.TransportErrorCode)) *PacketHandlercloseWithTransportErrorCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// destroy mocks base method.
|
||||
func (m *MockPacketHandler) destroy(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -656,6 +656,42 @@ func (c *QUICConnSendDatagramCall) DoAndReturn(f func([]byte) error) *QUICConnSe
|
|||
return c
|
||||
}
|
||||
|
||||
// closeWithTransportError mocks base method.
|
||||
func (m *MockQUICConn) closeWithTransportError(arg0 qerr.TransportErrorCode) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "closeWithTransportError", arg0)
|
||||
}
|
||||
|
||||
// closeWithTransportError indicates an expected call of closeWithTransportError.
|
||||
func (mr *MockQUICConnMockRecorder) closeWithTransportError(arg0 any) *QUICConncloseWithTransportErrorCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithTransportError", reflect.TypeOf((*MockQUICConn)(nil).closeWithTransportError), arg0)
|
||||
return &QUICConncloseWithTransportErrorCall{Call: call}
|
||||
}
|
||||
|
||||
// QUICConncloseWithTransportErrorCall wrap *gomock.Call
|
||||
type QUICConncloseWithTransportErrorCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *QUICConncloseWithTransportErrorCall) Return() *QUICConncloseWithTransportErrorCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *QUICConncloseWithTransportErrorCall) Do(f func(qerr.TransportErrorCode)) *QUICConncloseWithTransportErrorCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *QUICConncloseWithTransportErrorCall) DoAndReturn(f func(qerr.TransportErrorCode)) *QUICConncloseWithTransportErrorCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// destroy mocks base method.
|
||||
func (m *MockQUICConn) destroy(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -25,6 +25,7 @@ var ErrServerClosed = errors.New("quic: server closed")
|
|||
type packetHandler interface {
|
||||
handlePacket(receivedPacket)
|
||||
destroy(error)
|
||||
closeWithTransportError(qerr.TransportErrorCode)
|
||||
getPerspective() protocol.Perspective
|
||||
}
|
||||
|
||||
|
@ -44,6 +45,7 @@ type quicConn interface {
|
|||
getPerspective() protocol.Perspective
|
||||
run() error
|
||||
destroy(error)
|
||||
closeWithTransportError(TransportErrorCode)
|
||||
}
|
||||
|
||||
type zeroRTTQueue struct {
|
||||
|
@ -693,7 +695,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
|
|||
// wait until the early connection is ready, the handshake fails, or the server is closed
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
return
|
||||
case <-conn.earlyConnReady():
|
||||
case <-connCtx.Done():
|
||||
|
@ -703,7 +705,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
|
|||
// wait until the handshake is complete (or fails)
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
return
|
||||
case <-conn.HandshakeComplete():
|
||||
case <-connCtx.Done():
|
||||
|
|
|
@ -327,7 +327,7 @@ var _ = Describe("Server", func() {
|
|||
Eventually(run).Should(BeClosed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
// shutdown
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
conn.EXPECT().closeWithTransportError(gomock.Any())
|
||||
})
|
||||
|
||||
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
||||
|
@ -530,7 +530,7 @@ var _ = Describe("Server", func() {
|
|||
Eventually(run).Should(BeClosed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
// shutdown
|
||||
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
|
||||
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
|
||||
})
|
||||
|
||||
It("drops packets if the receive queue is full", func() {
|
||||
|
@ -570,7 +570,7 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
|
||||
// shutdown
|
||||
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
|
||||
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
|
||||
return conn
|
||||
}
|
||||
|
||||
|
@ -1008,7 +1008,7 @@ var _ = Describe("Server", func() {
|
|||
) quicConn {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().handlePacket(gomock.Any())
|
||||
conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}).Do(func(error) { close(destroyed) })
|
||||
conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) })
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().run().MaxTimes(1)
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
|
@ -1468,7 +1468,7 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().Context().Return(context.Background())
|
||||
close(called)
|
||||
// shutdown
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
conn.EXPECT().closeWithTransportError(gomock.Any())
|
||||
return conn
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue