mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
add a GetConfigForClient callback to the Config
This commit is contained in:
parent
ba942715db
commit
bc7cb706c5
9 changed files with 223 additions and 47 deletions
|
@ -103,6 +103,7 @@ func populateConfig(config *Config) *Config {
|
|||
}
|
||||
|
||||
return &Config{
|
||||
GetConfigForClient: config.GetConfigForClient,
|
||||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
|
@ -45,7 +46,7 @@ var _ = Describe("Config", func() {
|
|||
}
|
||||
|
||||
switch fn := typ.Field(i).Name; fn {
|
||||
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT":
|
||||
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT":
|
||||
// Can't compare functions.
|
||||
case "Versions":
|
||||
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
|
||||
|
@ -108,6 +109,7 @@ var _ = Describe("Config", func() {
|
|||
It("clones function fields", func() {
|
||||
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
|
||||
c1 := &Config{
|
||||
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
|
||||
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
|
||||
}
|
||||
|
@ -116,6 +118,8 @@ var _ = Describe("Config", func() {
|
|||
Expect(calledAddrValidation).To(BeTrue())
|
||||
c2.AllowConnectionWindowIncrease(nil, 1234)
|
||||
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
|
||||
_, err := c2.GetConfigForClient(&ClientHelloInfo{})
|
||||
Expect(err).To(MatchError("nope"))
|
||||
})
|
||||
|
||||
It("clones non-function fields", func() {
|
||||
|
@ -164,6 +168,7 @@ var _ = Describe("Config", func() {
|
|||
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
|
||||
Expect(c.DisableVersionNegotiationPackets).To(BeFalse())
|
||||
Expect(c.DisablePathMTUDiscovery).To(BeFalse())
|
||||
Expect(c.GetConfigForClient).To(BeNil())
|
||||
})
|
||||
|
||||
It("populates empty fields with default values, for the server", func() {
|
||||
|
|
|
@ -436,6 +436,72 @@ var _ = Describe("Handshake tests", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("GetConfigForClient", func() {
|
||||
It("uses the quic.Config returned by GetConfigForClient", func() {
|
||||
serverConfig.EnableDatagrams = false
|
||||
var calledFrom net.Addr
|
||||
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
||||
conf := serverConfig.Clone()
|
||||
conf.EnableDatagrams = true
|
||||
calledFrom = info.RemoteAddr
|
||||
return getQuicConfig(conf), nil
|
||||
}
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
cs := conn.ConnectionState()
|
||||
Expect(cs.SupportsDatagrams).To(BeTrue())
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port))
|
||||
})
|
||||
|
||||
It("rejects the connection attempt if GetConfigForClient errors", func() {
|
||||
serverConfig.EnableDatagrams = false
|
||||
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
||||
return nil, errors.New("rejected")
|
||||
}
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := ln.Accept(context.Background())
|
||||
Expect(err).To(HaveOccurred()) // we don't expect to accept any connection
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
var transportErr *quic.TransportError
|
||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||
Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused))
|
||||
})
|
||||
})
|
||||
|
||||
It("doesn't send any packets when generating the ClientHello fails", func() {
|
||||
ln, err := net.ListenUDP("udp", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -239,6 +239,9 @@ type ConnectionIDGenerator interface {
|
|||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
type Config struct {
|
||||
// GetConfigForClient is called for incoming connections.
|
||||
// If the error is not nil, the connection attempt is refused.
|
||||
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
Versions []VersionNumber
|
||||
|
@ -324,6 +327,10 @@ type Config struct {
|
|||
Tracer logging.Tracer
|
||||
}
|
||||
|
||||
type ClientHelloInfo struct {
|
||||
RemoteAddr net.Addr
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about a QUIC connection
|
||||
type ConnectionState struct {
|
||||
TLS handshake.ConnectionState
|
||||
|
|
|
@ -61,7 +61,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa
|
|||
}
|
||||
|
||||
// AddWithConnID mocks base method.
|
||||
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool {
|
||||
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() (packetHandler, bool)) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(bool)
|
||||
|
|
|
@ -122,7 +122,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
|
|||
return true
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
|
@ -130,7 +130,10 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
|
|||
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
||||
return false
|
||||
}
|
||||
conn := fn()
|
||||
conn, ok := fn()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
h.handlers[clientDestConnID] = conn
|
||||
h.handlers[newConnID] = conn
|
||||
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
|
||||
|
|
|
@ -62,14 +62,14 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
var called bool
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(m.AddWithConnID(connID1, connID2, func() packetHandler {
|
||||
Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) {
|
||||
called = true
|
||||
return NewMockPacketHandler(mockCtrl)
|
||||
return NewMockPacketHandler(mockCtrl), true
|
||||
})).To(BeTrue())
|
||||
Expect(called).To(BeTrue())
|
||||
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler {
|
||||
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) {
|
||||
Fail("didn't expect the constructor to be executed")
|
||||
return nil
|
||||
return nil, false
|
||||
})).To(BeFalse())
|
||||
})
|
||||
|
||||
|
|
25
server.go
25
server.go
|
@ -33,7 +33,7 @@ type packetHandler interface {
|
|||
type packetHandlerManager interface {
|
||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
|
||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
|
||||
Close(error)
|
||||
CloseServer()
|
||||
connRunner
|
||||
|
@ -584,7 +584,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||
var conn quicConn
|
||||
tracingID := nextConnTracingID()
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
|
||||
var tracer logging.ConnectionTracer
|
||||
if s.config.Tracer != nil {
|
||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||
|
@ -598,6 +598,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
connID,
|
||||
)
|
||||
}
|
||||
config := s.config
|
||||
if s.config.GetConfigForClient != nil {
|
||||
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
|
||||
if err != nil {
|
||||
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
||||
return nil, false
|
||||
}
|
||||
config = populateConfig(conf)
|
||||
}
|
||||
conn = s.newConn(
|
||||
newSendConn(s.conn, p.remoteAddr, p.info),
|
||||
s.connHandler,
|
||||
|
@ -608,7 +617,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
connID,
|
||||
s.connIDGenerator,
|
||||
s.connHandler.GetStatelessResetToken(connID),
|
||||
s.config,
|
||||
config,
|
||||
s.tlsConf,
|
||||
s.tokenGenerator,
|
||||
clientAddrIsValid,
|
||||
|
@ -626,10 +635,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
|||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
}
|
||||
|
||||
return conn
|
||||
return conn, true
|
||||
}); !added {
|
||||
// TODO: don't just drop the packet
|
||||
// Properly reject the connection attempt.
|
||||
go func() {
|
||||
defer p.buffer.Release()
|
||||
if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil {
|
||||
s.logger.Debugf("Error rejecting connection: %s", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
go conn.run()
|
||||
|
|
147
server_test.go
147
server_test.go
|
@ -267,14 +267,14 @@ var _ = Describe("Server", func() {
|
|||
var newConnID protocol.ConnectionID
|
||||
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
@ -468,14 +468,14 @@ var _ = Describe("Server", func() {
|
|||
var newConnID protocol.ConnectionID
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().Get(connID),
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}),
|
||||
)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
|
||||
|
@ -532,10 +532,10 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("drops packets if the receive queue is full", func() {
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).AnyTimes()
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
|
||||
|
||||
|
@ -594,7 +594,6 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("only creates a single connection for a duplicate Initial", func() {
|
||||
var createdConn bool
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
|
@ -615,15 +614,19 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
createdConn = true
|
||||
return conn
|
||||
return NewMockQUICConn(mockCtrl)
|
||||
}
|
||||
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
p := getInitial(connID)
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
|
||||
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
||||
Expect(createdConn).To(BeFalse())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
|
@ -657,10 +660,10 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).Times(protocol.MaxAcceptQueueSize)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
|
||||
|
@ -729,10 +732,10 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
|
||||
|
||||
|
@ -792,7 +795,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -994,6 +997,84 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("uses the config returned by GetConfigClient", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
||||
conf := &Config{MaxIncomingStreams: 1234}
|
||||
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := serv.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).To(Equal(conn))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
handshakeChan := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
conf *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
conn.EXPECT().handlePacket(gomock.Any())
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
conn.EXPECT().run().Do(func() {})
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
|
||||
)
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
close(handshakeChan) // complete the handshake
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects a connection attempt when GetConfigClient returns an error", func() {
|
||||
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
rejectHdr := parseHeader(b)
|
||||
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
|
||||
)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("accepts new connections when the handshake completes", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
||||
|
@ -1033,10 +1114,10 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
|
||||
serv.handleInitialImpl(
|
||||
|
@ -1107,10 +1188,10 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.baseServer.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
|
@ -1154,10 +1235,10 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).Times(protocol.MaxAcceptQueueSize)
|
||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
|
||||
|
@ -1216,10 +1297,10 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.baseServer.handlePacket(p)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
|
@ -1346,10 +1427,10 @@ var _ = Describe("Server", func() {
|
|||
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.handlePacket(initial)
|
||||
Eventually(called).Should(BeClosed())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue