add a GetConfigForClient callback to the Config

This commit is contained in:
Marten Seemann 2023-04-25 11:31:01 +02:00
parent ba942715db
commit bc7cb706c5
9 changed files with 223 additions and 47 deletions

View file

@ -267,14 +267,14 @@ var _ = Describe("Server", func() {
var newConnID protocol.ConnectionID
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))
conn := NewMockQUICConn(mockCtrl)
@ -468,14 +468,14 @@ var _ = Describe("Server", func() {
var newConnID protocol.ConnectionID
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
fn()
return true
_, ok := fn()
return ok
}),
)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
@ -532,10 +532,10 @@ var _ = Describe("Server", func() {
It("drops packets if the receive queue is full", func() {
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).AnyTimes()
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
@ -594,7 +594,6 @@ var _ = Describe("Server", func() {
It("only creates a single connection for a duplicate Initial", func() {
var createdConn bool
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
runner connRunner,
@ -615,15 +614,19 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber,
) quicConn {
createdConn = true
return conn
return NewMockQUICConn(mockCtrl)
}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID)
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdConn).To(BeFalse())
Eventually(done).Should(BeClosed())
})
It("rejects new connection attempts if the accept queue is full", func() {
@ -657,10 +660,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
@ -729,10 +732,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
@ -792,7 +795,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@ -994,6 +997,84 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("uses the config returned by GetConfigClient", func() {
conn := NewMockQUICConn(mockCtrl)
conf := &Config{MaxIncomingStreams: 1234}
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(s).To(Equal(conn))
close(done)
}()
handshakeChan := make(chan struct{})
serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
conf *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().run().Do(func() {})
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
)
Consistently(done).ShouldNot(BeClosed())
close(handshakeChan) // complete the handshake
Eventually(done).Should(BeClosed())
})
It("rejects a connection attempt when GetConfigClient returns an error", func() {
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
_, ok := fn()
return ok
})
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
rejectHdr := parseHeader(b)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
return len(b), nil
})
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
)
Eventually(done).Should(BeClosed())
})
It("accepts new connections when the handshake completes", func() {
conn := NewMockQUICConn(mockCtrl)
@ -1033,10 +1114,10 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handleInitialImpl(
@ -1107,10 +1188,10 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.baseServer.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
@ -1154,10 +1235,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
@ -1216,10 +1297,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.baseServer.handlePacket(p)
// make sure there are no Write calls on the packet conn
@ -1346,10 +1427,10 @@ var _ = Describe("Server", func() {
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.handlePacket(initial)
Eventually(called).Should(BeClosed())