mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
don't close established connections on Listener.Close, when using a Transport (#4072)
* don't close established connections on Listener.Close * only close once
This commit is contained in:
parent
ef800d6f71
commit
dda63b90eb
12 changed files with 136 additions and 151 deletions
|
@ -14,7 +14,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Stream deadline tests", func() {
|
var _ = Describe("Stream deadline tests", func() {
|
||||||
setup := func() (*quic.Listener, quic.Stream, quic.Stream) {
|
setup := func() (serverStr, clientStr quic.Stream, close func()) {
|
||||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
strChan := make(chan quic.SendStream)
|
strChan := make(chan quic.SendStream)
|
||||||
|
@ -36,19 +36,21 @@ var _ = Describe("Stream deadline tests", func() {
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
clientStr, err := conn.OpenStream()
|
clientStr, err = conn.OpenStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
|
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var serverStr quic.Stream
|
|
||||||
Eventually(strChan).Should(Receive(&serverStr))
|
Eventually(strChan).Should(Receive(&serverStr))
|
||||||
return server, serverStr, clientStr
|
return serverStr, clientStr, func() {
|
||||||
|
Expect(server.Close()).To(Succeed())
|
||||||
|
Expect(conn.CloseWithError(0, "")).To(Succeed())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Context("read deadlines", func() {
|
Context("read deadlines", func() {
|
||||||
It("completes a transfer when the deadline is set", func() {
|
It("completes a transfer when the deadline is set", func() {
|
||||||
server, serverStr, clientStr := setup()
|
serverStr, clientStr, closeFn := setup()
|
||||||
defer server.Close()
|
defer closeFn()
|
||||||
|
|
||||||
const timeout = time.Millisecond
|
const timeout = time.Millisecond
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
@ -82,8 +84,8 @@ var _ = Describe("Stream deadline tests", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("completes a transfer when the deadline is set concurrently", func() {
|
It("completes a transfer when the deadline is set concurrently", func() {
|
||||||
server, serverStr, clientStr := setup()
|
serverStr, clientStr, closeFn := setup()
|
||||||
defer server.Close()
|
defer closeFn()
|
||||||
|
|
||||||
const timeout = time.Millisecond
|
const timeout = time.Millisecond
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -132,8 +134,8 @@ var _ = Describe("Stream deadline tests", func() {
|
||||||
|
|
||||||
Context("write deadlines", func() {
|
Context("write deadlines", func() {
|
||||||
It("completes a transfer when the deadline is set", func() {
|
It("completes a transfer when the deadline is set", func() {
|
||||||
server, serverStr, clientStr := setup()
|
serverStr, clientStr, closeFn := setup()
|
||||||
defer server.Close()
|
defer closeFn()
|
||||||
|
|
||||||
const timeout = time.Millisecond
|
const timeout = time.Millisecond
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
@ -165,8 +167,8 @@ var _ = Describe("Stream deadline tests", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("completes a transfer when the deadline is set concurrently", func() {
|
It("completes a transfer when the deadline is set concurrently", func() {
|
||||||
server, serverStr, clientStr := setup()
|
serverStr, clientStr, closeFn := setup()
|
||||||
defer server.Close()
|
defer closeFn()
|
||||||
|
|
||||||
const timeout = time.Millisecond
|
const timeout = time.Millisecond
|
||||||
readDone := make(chan struct{})
|
readDone := make(chan struct{})
|
||||||
|
|
|
@ -152,13 +152,14 @@ var _ = Describe("Handshake tests", func() {
|
||||||
Context("Certificate validation", func() {
|
Context("Certificate validation", func() {
|
||||||
It("accepts the certificate", func() {
|
It("accepts the certificate", func() {
|
||||||
runServer(getTLSConfig())
|
runServer(getTLSConfig())
|
||||||
_, err := quic.DialAddr(
|
conn, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
conn.CloseWithError(0, "")
|
||||||
})
|
})
|
||||||
|
|
||||||
It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
|
It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
|
||||||
|
@ -187,6 +188,7 @@ var _ = Describe("Handshake tests", func() {
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn.CloseWithError(0, "")
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
Expect(server.Addr()).To(Equal(local))
|
Expect(server.Addr()).To(Equal(local))
|
||||||
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
|
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
|
||||||
|
@ -196,13 +198,14 @@ var _ = Describe("Handshake tests", func() {
|
||||||
|
|
||||||
It("works with a long certificate chain", func() {
|
It("works with a long certificate chain", func() {
|
||||||
runServer(getTLSConfigWithLongCertChain())
|
runServer(getTLSConfigWithLongCertChain())
|
||||||
_, err := quic.DialAddr(
|
conn, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
conn.CloseWithError(0, "")
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if the server name doesn't match", func() {
|
It("errors if the server name doesn't match", func() {
|
||||||
|
|
|
@ -52,22 +52,23 @@ var _ = Describe("TLS session resumption", func() {
|
||||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||||
tlsConf := getTLSClientConfig()
|
tlsConf := getTLSClientConfig()
|
||||||
tlsConf.ClientSessionCache = cache
|
tlsConf.ClientSessionCache = cache
|
||||||
conn, err := quic.DialAddr(
|
conn1, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn1.CloseWithError(0, "")
|
||||||
var sessionKey string
|
var sessionKey string
|
||||||
Eventually(puts).Should(Receive(&sessionKey))
|
Eventually(puts).Should(Receive(&sessionKey))
|
||||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||||
|
|
||||||
serverConn, err := server.Accept(context.Background())
|
serverConn, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||||
|
|
||||||
conn, err = quic.DialAddr(
|
conn2, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
tlsConf,
|
tlsConf,
|
||||||
|
@ -75,11 +76,12 @@ var _ = Describe("TLS session resumption", func() {
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(gets).To(Receive(Equal(sessionKey)))
|
Expect(gets).To(Receive(Equal(sessionKey)))
|
||||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue())
|
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
|
||||||
|
|
||||||
serverConn, err = server.Accept(context.Background())
|
serverConn, err = server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
|
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
|
||||||
|
conn2.CloseWithError(0, "")
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't use session resumption, if the config disables it", func() {
|
It("doesn't use session resumption, if the config disables it", func() {
|
||||||
|
@ -94,15 +96,16 @@ var _ = Describe("TLS session resumption", func() {
|
||||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||||
tlsConf := getTLSClientConfig()
|
tlsConf := getTLSClientConfig()
|
||||||
tlsConf.ClientSessionCache = cache
|
tlsConf.ClientSessionCache = cache
|
||||||
conn, err := quic.DialAddr(
|
conn1, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn1.CloseWithError(0, "")
|
||||||
Consistently(puts).ShouldNot(Receive())
|
Consistently(puts).ShouldNot(Receive())
|
||||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -110,14 +113,15 @@ var _ = Describe("TLS session resumption", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||||
|
|
||||||
conn, err = quic.DialAddr(
|
conn2, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||||
|
defer conn2.CloseWithError(0, "")
|
||||||
|
|
||||||
serverConn, err = server.Accept(context.Background())
|
serverConn, err = server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -142,7 +146,7 @@ var _ = Describe("TLS session resumption", func() {
|
||||||
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
|
||||||
tlsConf := getTLSClientConfig()
|
tlsConf := getTLSClientConfig()
|
||||||
tlsConf.ClientSessionCache = cache
|
tlsConf.ClientSessionCache = cache
|
||||||
conn, err := quic.DialAddr(
|
conn1, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
tlsConf,
|
tlsConf,
|
||||||
|
@ -150,7 +154,8 @@ var _ = Describe("TLS session resumption", func() {
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Consistently(puts).ShouldNot(Receive())
|
Consistently(puts).ShouldNot(Receive())
|
||||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||||
|
defer conn1.CloseWithError(0, "")
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -158,14 +163,15 @@ var _ = Describe("TLS session resumption", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||||
|
|
||||||
conn, err = quic.DialAddr(
|
conn2, err := quic.DialAddr(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
tlsConf,
|
tlsConf,
|
||||||
getQuicConfig(nil),
|
getQuicConfig(nil),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
|
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
|
||||||
|
defer conn2.CloseWithError(0, "")
|
||||||
|
|
||||||
serverConn, err = server.Accept(context.Background())
|
serverConn, err = server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -185,11 +185,13 @@ var _ = Describe("Timeout tests", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
|
serverConnChan := make(chan quic.Connection, 1)
|
||||||
serverConnClosed := make(chan struct{})
|
serverConnClosed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
conn, err := server.Accept(context.Background())
|
conn, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
serverConnChan <- conn
|
||||||
conn.AcceptStream(context.Background()) // blocks until the connection is closed
|
conn.AcceptStream(context.Background()) // blocks until the connection is closed
|
||||||
close(serverConnClosed)
|
close(serverConnClosed)
|
||||||
}()
|
}()
|
||||||
|
@ -240,7 +242,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
||||||
|
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
Expect(server.Close()).To(Succeed())
|
(<-serverConnChan).CloseWithError(0, "")
|
||||||
Eventually(serverConnClosed).Should(BeClosed())
|
Eventually(serverConnClosed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -266,11 +268,13 @@ var _ = Describe("Timeout tests", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer proxy.Close()
|
defer proxy.Close()
|
||||||
|
|
||||||
|
serverConnChan := make(chan quic.Connection, 1)
|
||||||
serverConnClosed := make(chan struct{})
|
serverConnClosed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
conn, err := server.Accept(context.Background())
|
conn, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
serverConnChan <- conn
|
||||||
<-conn.Context().Done() // block until the connection is closed
|
<-conn.Context().Done() // block until the connection is closed
|
||||||
close(serverConnClosed)
|
close(serverConnClosed)
|
||||||
}()
|
}()
|
||||||
|
@ -309,7 +313,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
Consistently(serverConnClosed).ShouldNot(BeClosed())
|
||||||
|
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
Expect(server.Close()).To(Succeed())
|
(<-serverConnChan).CloseWithError(0, "")
|
||||||
Eventually(serverConnClosed).Should(BeClosed())
|
Eventually(serverConnClosed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -325,11 +329,13 @@ var _ = Describe("Timeout tests", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
|
serverConnChan := make(chan quic.Connection, 1)
|
||||||
serverConnClosed := make(chan struct{})
|
serverConnClosed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
conn, err := server.Accept(context.Background())
|
conn, err := server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
serverConnChan <- conn
|
||||||
conn.AcceptStream(context.Background()) // blocks until the connection is closed
|
conn.AcceptStream(context.Background()) // blocks until the connection is closed
|
||||||
close(serverConnClosed)
|
close(serverConnClosed)
|
||||||
}()
|
}()
|
||||||
|
@ -370,7 +376,7 @@ var _ = Describe("Timeout tests", func() {
|
||||||
_, err = str.Write([]byte("foobar"))
|
_, err = str.Write([]byte("foobar"))
|
||||||
checkTimeoutError(err)
|
checkTimeoutError(err)
|
||||||
|
|
||||||
Expect(server.Close()).To(Succeed())
|
(<-serverConnChan).CloseWithError(0, "")
|
||||||
Eventually(serverConnClosed).Should(BeClosed())
|
Eventually(serverConnClosed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -142,5 +142,6 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||||
runReceivingPeer(client)
|
runReceivingPeer(client)
|
||||||
<-done1
|
<-done1
|
||||||
<-done2
|
<-done2
|
||||||
|
client.CloseWithError(0, "")
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -186,42 +186,6 @@ func (c *PacketHandlerManagerCloseCall) DoAndReturn(f func(error)) *PacketHandle
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseServer mocks base method.
|
|
||||||
func (m *MockPacketHandlerManager) CloseServer() {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "CloseServer")
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloseServer indicates an expected call of CloseServer.
|
|
||||||
func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *PacketHandlerManagerCloseServerCall {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
|
||||||
return &PacketHandlerManagerCloseServerCall{Call: call}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PacketHandlerManagerCloseServerCall wrap *gomock.Call
|
|
||||||
type PacketHandlerManagerCloseServerCall struct {
|
|
||||||
*gomock.Call
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
|
||||||
func (c *PacketHandlerManagerCloseServerCall) Return() *PacketHandlerManagerCloseServerCall {
|
|
||||||
c.Call = c.Call.Return()
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
|
||||||
func (c *PacketHandlerManagerCloseServerCall) Do(f func()) *PacketHandlerManagerCloseServerCall {
|
|
||||||
c.Call = c.Call.Do(f)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
|
||||||
func (c *PacketHandlerManagerCloseServerCall) DoAndReturn(f func()) *PacketHandlerManagerCloseServerCall {
|
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get mocks base method.
|
// Get mocks base method.
|
||||||
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -220,23 +220,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (
|
||||||
return handler, ok
|
return handler, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) CloseServer() {
|
|
||||||
h.mutex.Lock()
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for _, handler := range h.handlers {
|
|
||||||
if handler.getPerspective() == protocol.PerspectiveServer {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(handler packetHandler) {
|
|
||||||
// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
|
||||||
handler.shutdown()
|
|
||||||
wg.Done()
|
|
||||||
}(handler)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.mutex.Unlock()
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *packetHandlerMap) Close(e error) {
|
func (h *packetHandlerMap) Close(e error) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
|
|
||||||
|
|
|
@ -159,23 +159,6 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes the server", func() {
|
|
||||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
conn := NewMockPacketHandler(mockCtrl)
|
|
||||||
if i%2 == 0 {
|
|
||||||
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
|
||||||
} else {
|
|
||||||
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
|
||||||
conn.EXPECT().shutdown()
|
|
||||||
}
|
|
||||||
b := make([]byte, 12)
|
|
||||||
rand.Read(b)
|
|
||||||
m.Add(protocol.ParseConnectionID(b), conn)
|
|
||||||
}
|
|
||||||
m.CloseServer()
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes", func() {
|
It("closes", func() {
|
||||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||||
testErr := errors.New("shutdown")
|
testErr := errors.New("shutdown")
|
||||||
|
|
62
server.go
62
server.go
|
@ -34,7 +34,6 @@ type packetHandlerManager interface {
|
||||||
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
|
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
|
||||||
Close(error)
|
Close(error)
|
||||||
CloseServer()
|
|
||||||
connRunner
|
connRunner
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,8 +60,6 @@ type rejectedPacket struct {
|
||||||
|
|
||||||
// A Listener of QUIC
|
// A Listener of QUIC
|
||||||
type baseServer struct {
|
type baseServer struct {
|
||||||
mutex sync.Mutex
|
|
||||||
|
|
||||||
disableVersionNegotiation bool
|
disableVersionNegotiation bool
|
||||||
acceptEarlyConns bool
|
acceptEarlyConns bool
|
||||||
|
|
||||||
|
@ -104,10 +101,11 @@ type baseServer struct {
|
||||||
protocol.VersionNumber,
|
protocol.VersionNumber,
|
||||||
) quicConn
|
) quicConn
|
||||||
|
|
||||||
serverError error
|
closeOnce sync.Once
|
||||||
errorChan chan struct{}
|
errorChan chan struct{} // is closed when the server is closed
|
||||||
closed bool
|
closeErr error
|
||||||
running chan struct{} // closed as soon as run() returns
|
running chan struct{} // closed as soon as run() returns
|
||||||
|
|
||||||
versionNegotiationQueue chan receivedPacket
|
versionNegotiationQueue chan receivedPacket
|
||||||
invalidTokenQueue chan rejectedPacket
|
invalidTokenQueue chan rejectedPacket
|
||||||
connectionRefusedQueue chan rejectedPacket
|
connectionRefusedQueue chan rejectedPacket
|
||||||
|
@ -132,7 +130,10 @@ func (l *Listener) Accept(ctx context.Context) (Connection, error) {
|
||||||
return l.baseServer.Accept(ctx)
|
return l.baseServer.Accept(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the server. All active connections will be closed.
|
// Close closes the listener.
|
||||||
|
// Accept will return ErrServerClosed as soon as all connections in the accept queue have been accepted.
|
||||||
|
// QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error.
|
||||||
|
// Closing the listener doesn't have any effect on already established connections.
|
||||||
func (l *Listener) Close() error {
|
func (l *Listener) Close() error {
|
||||||
return l.baseServer.Close()
|
return l.baseServer.Close()
|
||||||
}
|
}
|
||||||
|
@ -321,38 +322,25 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
|
||||||
atomic.AddInt32(&s.connQueueLen, -1)
|
atomic.AddInt32(&s.connQueueLen, -1)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
case <-s.errorChan:
|
case <-s.errorChan:
|
||||||
return nil, s.serverError
|
return nil, s.closeErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the server
|
|
||||||
func (s *baseServer) Close() error {
|
func (s *baseServer) Close() error {
|
||||||
s.mutex.Lock()
|
s.close(ErrServerClosed, true)
|
||||||
if s.closed {
|
|
||||||
s.mutex.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if s.serverError == nil {
|
|
||||||
s.serverError = ErrServerClosed
|
|
||||||
}
|
|
||||||
s.closed = true
|
|
||||||
close(s.errorChan)
|
|
||||||
s.mutex.Unlock()
|
|
||||||
|
|
||||||
<-s.running
|
|
||||||
s.onClose()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *baseServer) setCloseError(e error) {
|
func (s *baseServer) close(e error, notifyOnClose bool) {
|
||||||
s.mutex.Lock()
|
s.closeOnce.Do(func() {
|
||||||
defer s.mutex.Unlock()
|
s.closeErr = e
|
||||||
if s.closed {
|
close(s.errorChan)
|
||||||
return
|
|
||||||
}
|
<-s.running
|
||||||
s.closed = true
|
if notifyOnClose {
|
||||||
s.serverError = e
|
s.onClose()
|
||||||
close(s.errorChan)
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Addr returns the server's network address
|
// Addr returns the server's network address
|
||||||
|
@ -701,8 +689,11 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
func (s *baseServer) handleNewConn(conn quicConn) {
|
func (s *baseServer) handleNewConn(conn quicConn) {
|
||||||
connCtx := conn.Context()
|
connCtx := conn.Context()
|
||||||
if s.acceptEarlyConns {
|
if s.acceptEarlyConns {
|
||||||
// wait until the early connection is ready (or the handshake fails)
|
// wait until the early connection is ready, the handshake fails, or the server is closed
|
||||||
select {
|
select {
|
||||||
|
case <-s.errorChan:
|
||||||
|
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
||||||
|
return
|
||||||
case <-conn.earlyConnReady():
|
case <-conn.earlyConnReady():
|
||||||
case <-connCtx.Done():
|
case <-connCtx.Done():
|
||||||
return
|
return
|
||||||
|
@ -710,6 +701,9 @@ func (s *baseServer) handleNewConn(conn quicConn) {
|
||||||
} else {
|
} else {
|
||||||
// wait until the handshake is complete (or fails)
|
// wait until the handshake is complete (or fails)
|
||||||
select {
|
select {
|
||||||
|
case <-s.errorChan:
|
||||||
|
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
||||||
|
return
|
||||||
case <-conn.HandshakeComplete():
|
case <-conn.HandshakeComplete():
|
||||||
case <-connCtx.Done():
|
case <-connCtx.Done():
|
||||||
return
|
return
|
||||||
|
|
|
@ -326,6 +326,8 @@ var _ = Describe("Server", func() {
|
||||||
// make sure we're using a server-generated connection ID
|
// make sure we're using a server-generated connection ID
|
||||||
Eventually(run).Should(BeClosed())
|
Eventually(run).Should(BeClosed())
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
// shutdown
|
||||||
|
conn.EXPECT().destroy(gomock.Any())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
||||||
|
@ -527,6 +529,8 @@ var _ = Describe("Server", func() {
|
||||||
// make sure we're using a server-generated connection ID
|
// make sure we're using a server-generated connection ID
|
||||||
Eventually(run).Should(BeClosed())
|
Eventually(run).Should(BeClosed())
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
// shutdown
|
||||||
|
conn.EXPECT().destroy(gomock.Any())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops packets if the receive queue is full", func() {
|
It("drops packets if the receive queue is full", func() {
|
||||||
|
@ -565,6 +569,8 @@ var _ = Describe("Server", func() {
|
||||||
conn.EXPECT().run().MaxTimes(1)
|
conn.EXPECT().run().MaxTimes(1)
|
||||||
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
|
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
|
||||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
|
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
|
||||||
|
// shutdown
|
||||||
|
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -956,30 +962,69 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("accepting connections", func() {
|
Context("accepting connections", func() {
|
||||||
It("returns Accept when an error occurs", func() {
|
It("returns Accept when closed", func() {
|
||||||
testErr := errors.New("test err")
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
_, err := serv.Accept(context.Background())
|
_, err := serv.Accept(context.Background())
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(ErrServerClosed))
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
serv.setCloseError(testErr)
|
serv.Close()
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
serv.onClose() // shutdown
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns immediately, if an error occurred before", func() {
|
It("returns immediately, if an error occurred before", func() {
|
||||||
testErr := errors.New("test err")
|
serv.Close()
|
||||||
serv.setCloseError(testErr)
|
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
_, err := serv.Accept(context.Background())
|
_, err := serv.Accept(context.Background())
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(ErrServerClosed))
|
||||||
}
|
}
|
||||||
serv.onClose() // shutdown
|
})
|
||||||
|
|
||||||
|
It("closes connection that are still handshaking after Close", func() {
|
||||||
|
serv.Close()
|
||||||
|
|
||||||
|
destroyed := make(chan struct{})
|
||||||
|
serv.newConn = func(
|
||||||
|
_ sendConn,
|
||||||
|
_ connRunner,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ *protocol.ConnectionID,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ ConnectionIDGenerator,
|
||||||
|
_ protocol.StatelessResetToken,
|
||||||
|
conf *Config,
|
||||||
|
_ *tls.Config,
|
||||||
|
_ *handshake.TokenGenerator,
|
||||||
|
_ bool,
|
||||||
|
_ *logging.ConnectionTracer,
|
||||||
|
_ uint64,
|
||||||
|
_ utils.Logger,
|
||||||
|
_ protocol.VersionNumber,
|
||||||
|
) quicConn {
|
||||||
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
|
conn.EXPECT().handlePacket(gomock.Any())
|
||||||
|
conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}).Do(func(error) { close(destroyed) })
|
||||||
|
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||||
|
conn.EXPECT().run().MaxTimes(1)
|
||||||
|
conn.EXPECT().Context().Return(context.Background())
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
phm.EXPECT().Get(gomock.Any())
|
||||||
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||||
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
_, ok := fn()
|
||||||
|
return ok
|
||||||
|
})
|
||||||
|
serv.handleInitialImpl(
|
||||||
|
receivedPacket{buffer: getPacketBuffer()},
|
||||||
|
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
|
||||||
|
)
|
||||||
|
Eventually(destroyed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns when the context is canceled", func() {
|
It("returns when the context is canceled", func() {
|
||||||
|
@ -1343,10 +1388,7 @@ var _ = Describe("Server", func() {
|
||||||
serv.connHandler = phm
|
serv.connHandler = phm
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() { tr.Close() })
|
||||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
|
||||||
tr.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
It("passes packets to existing connections", func() {
|
It("passes packets to existing connections", func() {
|
||||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||||
|
@ -1425,6 +1467,8 @@ var _ = Describe("Server", func() {
|
||||||
conn.EXPECT().earlyConnReady()
|
conn.EXPECT().earlyConnReady()
|
||||||
conn.EXPECT().Context().Return(context.Background())
|
conn.EXPECT().Context().Return(context.Background())
|
||||||
close(called)
|
close(called)
|
||||||
|
// shutdown
|
||||||
|
conn.EXPECT().destroy(gomock.Any())
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -275,7 +275,8 @@ func (t *Transport) runSendQueue() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the underlying connection and waits until listen has returned.
|
// Close closes the underlying connection.
|
||||||
|
// If any listener was started, it will be closed as well.
|
||||||
// It is invalid to start new listeners or connections after that.
|
// It is invalid to start new listeners or connections after that.
|
||||||
func (t *Transport) Close() error {
|
func (t *Transport) Close() error {
|
||||||
t.close(errors.New("closing"))
|
t.close(errors.New("closing"))
|
||||||
|
@ -294,7 +295,6 @@ func (t *Transport) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) closeServer() {
|
func (t *Transport) closeServer() {
|
||||||
t.handlerMap.CloseServer()
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
t.server = nil
|
t.server = nil
|
||||||
if t.isSingleUse {
|
if t.isSingleUse {
|
||||||
|
@ -322,7 +322,7 @@ func (t *Transport) close(e error) {
|
||||||
t.handlerMap.Close(e)
|
t.handlerMap.Close(e)
|
||||||
}
|
}
|
||||||
if t.server != nil {
|
if t.server != nil {
|
||||||
t.server.setCloseError(e)
|
t.server.close(e, false)
|
||||||
}
|
}
|
||||||
t.closed = true
|
t.closed = true
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,7 +114,6 @@ var _ = Describe("Transport", func() {
|
||||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||||
tr.handlerMap = phm
|
tr.handlerMap = phm
|
||||||
|
|
||||||
phm.EXPECT().CloseServer()
|
|
||||||
Expect(ln.Close()).To(Succeed())
|
Expect(ln.Close()).To(Succeed())
|
||||||
|
|
||||||
// shutdown
|
// shutdown
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue