change the source connection ID when creating a new IETF QUIC session

This commit is contained in:
Marten Seemann 2018-05-17 13:33:02 +09:00
parent 8bfb29f06e
commit d7dee33bc7
5 changed files with 35 additions and 51 deletions

View file

@ -331,7 +331,6 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
}
func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
// TODO(#1003): add support for server-chosen connection IDs
// reject packets with the wrong connection ID
if !hdr.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)

View file

@ -165,15 +165,10 @@ func (s *server) setupTLS() error {
case <-s.errorChan:
return
case tlsSession := <-sessionChan:
connID := tlsSession.connID
sess := tlsSession.sess
if _, ok := s.sessionHandler.Get(connID); ok { // drop this session if it already exists
continue
}
// TODO(#1003): There's a race condition here.
// If another connection with the same conn ID is added between Get() and Add(), it would be overwritten.
// We can avoid this be using server-chosen connection IDs.
s.sessionHandler.Add(connID, sess)
// The connection ID is a randomly chosen 8 byte value.
// It is safe to assume that it doesn't collide with other randomly chosen values.
s.sessionHandler.Add(tlsSession.connID, sess)
go sess.run()
}
}

View file

@ -158,7 +158,6 @@ var _ = Describe("Server", func() {
sess.EXPECT().run().Do(func() { close(run) })
err := serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
sessionHandler.EXPECT().Get(connID)
sessionHandler.EXPECT().Add(connID, sess)
serv.serverTLS.sessionChan <- tlsSession{
connID: connID,
@ -167,24 +166,6 @@ var _ = Describe("Server", func() {
Eventually(run).Should(BeClosed())
})
It("only accepts one new TLS sessions for one connection ID", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
sess := NewMockPacketHandler(mockCtrl)
err := serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
sessionHandler.EXPECT().Get(connID).Return(NewMockPacketHandler(mockCtrl), true).Do(func(protocol.ConnectionID) {
close(done)
})
// don't EXPECT any calls to sessionHandler.Add
serv.serverTLS.sessionChan <- tlsSession{
connID: connID,
sess: sess,
}
Eventually(done).Should(BeClosed())
})
It("accepts a session once the connection it is forward secure", func() {
s := NewMockPacketHandler(mockCtrl)
s.EXPECT().handlePacket(gomock.Any())

View file

@ -92,7 +92,7 @@ func newServerTLS(
func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) {
// TODO: add a check that DestConnID == SrcConnID
s.logger.Debugf("Received a Packet. Handling it statelessly.")
sess, err := s.handleInitialImpl(remoteAddr, hdr, data)
sess, connID, err := s.handleInitialImpl(remoteAddr, hdr, data)
if err != nil {
s.logger.Errorf("Error occurred handling initial packet: %s", err)
return
@ -101,7 +101,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []
return
}
s.sessionChan <- tlsSession{
connID: hdr.DestConnectionID,
connID: connID,
sess: sess,
}
}
@ -135,48 +135,48 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea
return err
}
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, protocol.ConnectionID, error) {
if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize {
return nil, errors.New("dropping too small Initial packet")
return nil, nil, errors.New("dropping too small Initial packet")
}
// check version, if not matching send VNP
if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
vnp, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.supportedVersions)
if err != nil {
return nil, err
return nil, nil, err
}
_, err = s.conn.WriteTo(vnp, remoteAddr)
return nil, err
return nil, nil, err
}
// unpack packet and check stream frame contents
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.DestConnectionID, protocol.VersionTLS)
if err != nil {
return nil, err
return nil, nil, err
}
frame, err := unpackInitialPacket(aead, hdr, data, s.logger, hdr.Version)
if err != nil {
s.logger.Debugf("Error unpacking initial packet: %s", err)
return nil, nil
return nil, nil, nil
}
sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
sess, connID, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
if err != nil {
if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil {
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr)
}
return nil, err
return nil, nil, err
}
return sess, nil
return sess, connID, nil
}
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) {
func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, protocol.ConnectionID, error) {
version := hdr.Version
bc := handshake.NewCryptoStreamConn(remoteAddr)
bc.AddDataForReading(frame.Data)
tls, paramsChan, err := s.newMintConn(bc, version)
if err != nil {
return nil, err
return nil, nil, err
}
alert := tls.Handshake()
if alert == mint.AlertStatelessRetry {
@ -197,29 +197,34 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
}
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger)
if err != nil {
return nil, err
return nil, nil, err
}
_, err = s.conn.WriteTo(data, remoteAddr)
return nil, err
return nil, nil, err
}
if alert != mint.AlertNoAlert {
return nil, alert
return nil, nil, alert
}
if tls.State() != mint.StateServerNegotiated {
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
}
if alert := tls.Handshake(); alert != mint.AlertNoAlert {
return nil, alert
return nil, nil, alert
}
if tls.State() != mint.StateServerWaitFlight2 {
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
}
params := <-paramsChan
connID, err := protocol.GenerateConnectionID()
if err != nil {
return nil, nil, err
}
s.logger.Debugf("Changing source connection ID to %s.", connID)
sess, err := newTLSServerSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionRunner,
hdr.SrcConnectionID,
hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here
connID,
protocol.PacketNumber(1), // TODO: use a random packet number here
s.config,
tls,
@ -230,10 +235,10 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
s.logger,
)
if err != nil {
return nil, err
return nil, nil, err
}
cs := sess.getCryptoStream()
cs.setReadOffset(frame.DataLen())
bc.SetStream(cs)
return sess, nil
return sess, connID, nil
}

View file

@ -146,7 +146,11 @@ var _ = Describe("Stateless TLS handling", func() {
Expect(conn.dataWritten.Len()).To(BeZero())
close(done)
}()
Eventually(sessionChan).Should(Receive())
var tlsSess tlsSession
Eventually(sessionChan).Should(Receive(&tlsSess))
// make sure we're using a server-generated connection ID
Expect(tlsSess.connID).ToNot(Equal(hdr.SrcConnectionID))
Expect(tlsSess.connID).ToNot(Equal(hdr.DestConnectionID))
Eventually(done).Should(BeClosed())
})