diff --git a/client.go b/client.go index 3c6aefb3..5fbeb856 100644 --- a/client.go +++ b/client.go @@ -331,7 +331,6 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { } func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { - // TODO(#1003): add support for server-chosen connection IDs // reject packets with the wrong connection ID if !hdr.DestConnectionID.Equal(c.srcConnID) { return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID) diff --git a/server.go b/server.go index 963801b7..ea0ab833 100644 --- a/server.go +++ b/server.go @@ -165,15 +165,10 @@ func (s *server) setupTLS() error { case <-s.errorChan: return case tlsSession := <-sessionChan: - connID := tlsSession.connID sess := tlsSession.sess - if _, ok := s.sessionHandler.Get(connID); ok { // drop this session if it already exists - continue - } - // TODO(#1003): There's a race condition here. - // If another connection with the same conn ID is added between Get() and Add(), it would be overwritten. - // We can avoid this be using server-chosen connection IDs. - s.sessionHandler.Add(connID, sess) + // The connection ID is a randomly chosen 8 byte value. + // It is safe to assume that it doesn't collide with other randomly chosen values. + s.sessionHandler.Add(tlsSession.connID, sess) go sess.run() } } diff --git a/server_test.go b/server_test.go index 8e2df4e4..9f214e6b 100644 --- a/server_test.go +++ b/server_test.go @@ -158,7 +158,6 @@ var _ = Describe("Server", func() { sess.EXPECT().run().Do(func() { close(run) }) err := serv.setupTLS() Expect(err).ToNot(HaveOccurred()) - sessionHandler.EXPECT().Get(connID) sessionHandler.EXPECT().Add(connID, sess) serv.serverTLS.sessionChan <- tlsSession{ connID: connID, @@ -167,24 +166,6 @@ var _ = Describe("Server", func() { Eventually(run).Should(BeClosed()) }) - It("only accepts one new TLS sessions for one connection ID", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - sess := NewMockPacketHandler(mockCtrl) - err := serv.setupTLS() - Expect(err).ToNot(HaveOccurred()) - - done := make(chan struct{}) - sessionHandler.EXPECT().Get(connID).Return(NewMockPacketHandler(mockCtrl), true).Do(func(protocol.ConnectionID) { - close(done) - }) - // don't EXPECT any calls to sessionHandler.Add - serv.serverTLS.sessionChan <- tlsSession{ - connID: connID, - sess: sess, - } - Eventually(done).Should(BeClosed()) - }) - It("accepts a session once the connection it is forward secure", func() { s := NewMockPacketHandler(mockCtrl) s.EXPECT().handlePacket(gomock.Any()) diff --git a/server_tls.go b/server_tls.go index 5d1f4ea1..fecdf523 100644 --- a/server_tls.go +++ b/server_tls.go @@ -92,7 +92,7 @@ func newServerTLS( func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) { // TODO: add a check that DestConnID == SrcConnID s.logger.Debugf("Received a Packet. Handling it statelessly.") - sess, err := s.handleInitialImpl(remoteAddr, hdr, data) + sess, connID, err := s.handleInitialImpl(remoteAddr, hdr, data) if err != nil { s.logger.Errorf("Error occurred handling initial packet: %s", err) return @@ -101,7 +101,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data [] return } s.sessionChan <- tlsSession{ - connID: hdr.DestConnectionID, + connID: connID, sess: sess, } } @@ -135,48 +135,48 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea return err } -func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) { +func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, protocol.ConnectionID, error) { if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize { - return nil, errors.New("dropping too small Initial packet") + return nil, nil, errors.New("dropping too small Initial packet") } // check version, if not matching send VNP if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) vnp, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.supportedVersions) if err != nil { - return nil, err + return nil, nil, err } _, err = s.conn.WriteTo(vnp, remoteAddr) - return nil, err + return nil, nil, err } // unpack packet and check stream frame contents aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.DestConnectionID, protocol.VersionTLS) if err != nil { - return nil, err + return nil, nil, err } frame, err := unpackInitialPacket(aead, hdr, data, s.logger, hdr.Version) if err != nil { s.logger.Debugf("Error unpacking initial packet: %s", err) - return nil, nil + return nil, nil, nil } - sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) + sess, connID, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) if err != nil { if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr) } - return nil, err + return nil, nil, err } - return sess, nil + return sess, connID, nil } -func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) { +func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, protocol.ConnectionID, error) { version := hdr.Version bc := handshake.NewCryptoStreamConn(remoteAddr) bc.AddDataForReading(frame.Data) tls, paramsChan, err := s.newMintConn(bc, version) if err != nil { - return nil, err + return nil, nil, err } alert := tls.Handshake() if alert == mint.AlertStatelessRetry { @@ -197,29 +197,34 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, } data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger) if err != nil { - return nil, err + return nil, nil, err } _, err = s.conn.WriteTo(data, remoteAddr) - return nil, err + return nil, nil, err } if alert != mint.AlertNoAlert { - return nil, alert + return nil, nil, alert } if tls.State() != mint.StateServerNegotiated { - return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State()) + return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State()) } if alert := tls.Handshake(); alert != mint.AlertNoAlert { - return nil, alert + return nil, nil, alert } if tls.State() != mint.StateServerWaitFlight2 { - return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) + return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State()) } params := <-paramsChan + connID, err := protocol.GenerateConnectionID() + if err != nil { + return nil, nil, err + } + s.logger.Debugf("Changing source connection ID to %s.", connID) sess, err := newTLSServerSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, s.sessionRunner, hdr.SrcConnectionID, - hdr.DestConnectionID, // TODO(#1003): we can use a server-chosen connection ID here + connID, protocol.PacketNumber(1), // TODO: use a random packet number here s.config, tls, @@ -230,10 +235,10 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, s.logger, ) if err != nil { - return nil, err + return nil, nil, err } cs := sess.getCryptoStream() cs.setReadOffset(frame.DataLen()) bc.SetStream(cs) - return sess, nil + return sess, connID, nil } diff --git a/server_tls_test.go b/server_tls_test.go index 4d4e764d..03122aca 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -146,7 +146,11 @@ var _ = Describe("Stateless TLS handling", func() { Expect(conn.dataWritten.Len()).To(BeZero()) close(done) }() - Eventually(sessionChan).Should(Receive()) + var tlsSess tlsSession + Eventually(sessionChan).Should(Receive(&tlsSess)) + // make sure we're using a server-generated connection ID + Expect(tlsSess.connID).ToNot(Equal(hdr.SrcConnectionID)) + Expect(tlsSess.connID).ToNot(Equal(hdr.DestConnectionID)) Eventually(done).Should(BeClosed()) })