mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 04:37:36 +03:00
add a GetConfigForClient callback to the Config
This commit is contained in:
parent
ba942715db
commit
bc7cb706c5
9 changed files with 223 additions and 47 deletions
147
server_test.go
147
server_test.go
|
@ -267,14 +267,14 @@ var _ = Describe("Server", func() {
|
|||
var newConnID protocol.ConnectionID
|
||||
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
@ -468,14 +468,14 @@ var _ = Describe("Server", func() {
|
|||
var newConnID protocol.ConnectionID
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().Get(connID),
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}),
|
||||
)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
|
||||
|
@ -532,10 +532,10 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("drops packets if the receive queue is full", func() {
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).AnyTimes()
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
|
||||
|
||||
|
@ -594,7 +594,6 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("only creates a single connection for a duplicate Initial", func() {
|
||||
var createdConn bool
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
|
@ -615,15 +614,19 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
createdConn = true
|
||||
return conn
|
||||
return NewMockQUICConn(mockCtrl)
|
||||
}
|
||||
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
p := getInitial(connID)
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
|
||||
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
||||
Expect(createdConn).To(BeFalse())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
|
@ -657,10 +660,10 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).Times(protocol.MaxAcceptQueueSize)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
|
||||
|
@ -729,10 +732,10 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
|
||||
|
||||
|
@ -792,7 +795,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -994,6 +997,84 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("uses the config returned by GetConfigClient", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
||||
conf := &Config{MaxIncomingStreams: 1234}
|
||||
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := serv.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).To(Equal(conn))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
handshakeChan := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
conf *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
conn.EXPECT().handlePacket(gomock.Any())
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
conn.EXPECT().run().Do(func() {})
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
|
||||
)
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
close(handshakeChan) // complete the handshake
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects a connection attempt when GetConfigClient returns an error", func() {
|
||||
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
rejectHdr := parseHeader(b)
|
||||
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
|
||||
)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("accepts new connections when the handshake completes", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
||||
|
@ -1033,10 +1114,10 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
|
||||
serv.handleInitialImpl(
|
||||
|
@ -1107,10 +1188,10 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.baseServer.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
|
@ -1154,10 +1235,10 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).Times(protocol.MaxAcceptQueueSize)
|
||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
|
||||
|
@ -1216,10 +1297,10 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.baseServer.handlePacket(p)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
|
@ -1346,10 +1427,10 @@ var _ = Describe("Server", func() {
|
|||
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.handlePacket(initial)
|
||||
Eventually(called).Should(BeClosed())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue