mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 21:57:36 +03:00
change the source connection ID when creating a new IETF QUIC session
This commit is contained in:
parent
8bfb29f06e
commit
d7dee33bc7
5 changed files with 35 additions and 51 deletions
|
@ -331,7 +331,6 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
|||
}
|
||||
|
||||
func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
|
||||
// TODO(#1003): add support for server-chosen connection IDs
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
|
||||
|
|
11
server.go
11
server.go
|
@ -165,15 +165,10 @@ func (s *server) setupTLS() error {
|
|||
case <-s.errorChan:
|
||||
return
|
||||
case tlsSession := <-sessionChan:
|
||||
connID := tlsSession.connID
|
||||
sess := tlsSession.sess
|
||||
if _, ok := s.sessionHandler.Get(connID); ok { // drop this session if it already exists
|
||||
continue
|
||||
}
|
||||
// TODO(#1003): There's a race condition here.
|
||||
// If another connection with the same conn ID is added between Get() and Add(), it would be overwritten.
|
||||
// We can avoid this be using server-chosen connection IDs.
|
||||
s.sessionHandler.Add(connID, sess)
|
||||
// The connection ID is a randomly chosen 8 byte value.
|
||||
// It is safe to assume that it doesn't collide with other randomly chosen values.
|
||||
s.sessionHandler.Add(tlsSession.connID, sess)
|
||||
go sess.run()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -158,7 +158,6 @@ var _ = Describe("Server", func() {
|
|||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
err := serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sessionHandler.EXPECT().Get(connID)
|
||||
sessionHandler.EXPECT().Add(connID, sess)
|
||||
serv.serverTLS.sessionChan <- tlsSession{
|
||||
connID: connID,
|
||||
|
@ -167,24 +166,6 @@ var _ = Describe("Server", func() {
|
|||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("only accepts one new TLS sessions for one connection ID", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
sess := NewMockPacketHandler(mockCtrl)
|
||||
err := serv.setupTLS()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
sessionHandler.EXPECT().Get(connID).Return(NewMockPacketHandler(mockCtrl), true).Do(func(protocol.ConnectionID) {
|
||||
close(done)
|
||||
})
|
||||
// don't EXPECT any calls to sessionHandler.Add
|
||||
serv.serverTLS.sessionChan <- tlsSession{
|
||||
connID: connID,
|
||||
sess: sess,
|
||||
}
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("accepts a session once the connection it is forward secure", func() {
|
||||
s := NewMockPacketHandler(mockCtrl)
|
||||
s.EXPECT().handlePacket(gomock.Any())
|
||||
|
|
|
@ -92,7 +92,7 @@ func newServerTLS(
|
|||
func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) {
|
||||
// TODO: add a check that DestConnID == SrcConnID
|
||||
s.logger.Debugf("Received a Packet. Handling it statelessly.")
|
||||
sess, err := s.handleInitialImpl(remoteAddr, hdr, data)
|
||||
sess, connID, err := s.handleInitialImpl(remoteAddr, hdr, data)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
||||
return
|
||||
|
@ -101,7 +101,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []
|
|||
return
|
||||
}
|
||||
s.sessionChan <- tlsSession{
|
||||
connID: hdr.DestConnectionID,
|
||||
connID: connID,
|
||||
sess: sess,
|
||||
}
|
||||
}
|
||||
|
@ -135,48 +135,48 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
|
||||
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, protocol.ConnectionID, error) {
|
||||
if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize {
|
||||
return nil, errors.New("dropping too small Initial packet")
|
||||
return nil, nil, errors.New("dropping too small Initial packet")
|
||||
}
|
||||
// check version, if not matching send VNP
|
||||
if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
|
||||
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
|
||||
vnp, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.supportedVersions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
_, err = s.conn.WriteTo(vnp, remoteAddr)
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// unpack packet and check stream frame contents
|
||||
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.DestConnectionID, protocol.VersionTLS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
frame, err := unpackInitialPacket(aead, hdr, data, s.logger, hdr.Version)
|
||||
if err != nil {
|
||||
s.logger.Debugf("Error unpacking initial packet: %s", err)
|
||||
return nil, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
|
||||
sess, connID, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
|
||||
if err != nil {
|
||||
if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil {
|
||||
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr)
|
||||
}
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
return sess, nil
|
||||
return sess, connID, nil
|
||||
}
|
||||
|
||||
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) {
|
||||
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, protocol.ConnectionID, error) {
|
||||
version := hdr.Version
|
||||
bc := handshake.NewCryptoStreamConn(remoteAddr)
|
||||
bc.AddDataForReading(frame.Data)
|
||||
tls, paramsChan, err := s.newMintConn(bc, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
alert := tls.Handshake()
|
||||
if alert == mint.AlertStatelessRetry {
|
||||
|
@ -197,29 +197,34 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
|
|||
}
|
||||
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
_, err = s.conn.WriteTo(data, remoteAddr)
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if alert != mint.AlertNoAlert {
|
||||
return nil, alert
|
||||
return nil, nil, alert
|
||||
}
|
||||
if tls.State() != mint.StateServerNegotiated {
|
||||
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
|
||||
return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
|
||||
}
|
||||
if alert := tls.Handshake(); alert != mint.AlertNoAlert {
|
||||
return nil, alert
|
||||
return nil, nil, alert
|
||||
}
|
||||
if tls.State() != mint.StateServerWaitFlight2 {
|
||||
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
|
||||
return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
|
||||
}
|
||||
params := <-paramsChan
|
||||
connID, err := protocol.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
s.logger.Debugf("Changing source connection ID to %s.", connID)
|
||||
sess, err := newTLSServerSession(
|
||||
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
||||
s.sessionRunner,
|
||||
hdr.SrcConnectionID,
|
||||
hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here
|
||||
connID,
|
||||
protocol.PacketNumber(1), // TODO: use a random packet number here
|
||||
s.config,
|
||||
tls,
|
||||
|
@ -230,10 +235,10 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
|
|||
s.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
cs := sess.getCryptoStream()
|
||||
cs.setReadOffset(frame.DataLen())
|
||||
bc.SetStream(cs)
|
||||
return sess, nil
|
||||
return sess, connID, nil
|
||||
}
|
||||
|
|
|
@ -146,7 +146,11 @@ var _ = Describe("Stateless TLS handling", func() {
|
|||
Expect(conn.dataWritten.Len()).To(BeZero())
|
||||
close(done)
|
||||
}()
|
||||
Eventually(sessionChan).Should(Receive())
|
||||
var tlsSess tlsSession
|
||||
Eventually(sessionChan).Should(Receive(&tlsSess))
|
||||
// make sure we're using a server-generated connection ID
|
||||
Expect(tlsSess.connID).ToNot(Equal(hdr.SrcConnectionID))
|
||||
Expect(tlsSess.connID).ToNot(Equal(hdr.DestConnectionID))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue