diff --git a/interface.go b/interface.go index 2567144f..5ba909cb 100644 --- a/interface.go +++ b/interface.go @@ -17,7 +17,10 @@ type StreamID = protocol.StreamID type VersionNumber = protocol.VersionNumber // A Cookie can be used to verify the ownership of the client address. -type Cookie = handshake.Cookie +type Cookie struct { + RemoteAddr string + SentTime time.Time +} // ConnectionState records basic details about the QUIC connection. type ConnectionState = handshake.ConnectionState diff --git a/internal/handshake/cookie_generator.go b/internal/handshake/cookie_generator.go index 00f6e7ef..6d1288ed 100644 --- a/internal/handshake/cookie_generator.go +++ b/internal/handshake/cookie_generator.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" ) const ( @@ -14,14 +16,17 @@ const ( // A Cookie is derived from the client address and can be used to verify the ownership of this address. type Cookie struct { - RemoteAddr string - // The time that the STK was issued (resolution 1 second) + RemoteAddr string + OriginalDestConnectionID protocol.ConnectionID + // The time that the Cookie was issued (resolution 1 second) SentTime time.Time } // token is the struct that is used for ASN1 serialization and deserialization type token struct { - Data []byte + RemoteAddr []byte + OriginalDestConnectionID []byte + Timestamp int64 } @@ -42,10 +47,11 @@ func NewCookieGenerator() (*CookieGenerator, error) { } // NewToken generates a new Cookie for a given source address -func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) { +func (g *CookieGenerator) NewToken(raddr net.Addr, origConnID protocol.ConnectionID) ([]byte, error) { data, err := asn1.Marshal(token{ - Data: encodeRemoteAddr(raddr), - Timestamp: time.Now().Unix(), + RemoteAddr: encodeRemoteAddr(raddr), + OriginalDestConnectionID: origConnID, + Timestamp: time.Now().Unix(), }) if err != nil { return nil, err @@ -72,10 +78,14 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) { if len(rest) != 0 { return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) } - return &Cookie{ - RemoteAddr: decodeRemoteAddr(t.Data), + cookie := &Cookie{ + RemoteAddr: decodeRemoteAddr(t.RemoteAddr), SentTime: time.Unix(t.Timestamp, 0), - }, nil + } + if len(t.OriginalDestConnectionID) > 0 { + cookie.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) + } + return cookie, nil } // encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie diff --git a/internal/handshake/cookie_generator_test.go b/internal/handshake/cookie_generator_test.go index f0480701..6c0a27b0 100644 --- a/internal/handshake/cookie_generator_test.go +++ b/internal/handshake/cookie_generator_test.go @@ -5,6 +5,8 @@ import ( "net" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -20,7 +22,7 @@ var _ = Describe("Cookie Generator", func() { It("generates a Cookie", func() { ip := net.IPv4(127, 0, 0, 1) - token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) + token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}, nil) Expect(err).ToNot(HaveOccurred()) Expect(token).ToNot(BeEmpty()) }) @@ -33,7 +35,10 @@ var _ = Describe("Cookie Generator", func() { It("accepts a valid cookie", func() { ip := net.IPv4(192, 168, 0, 1) - token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) + token, err := cookieGen.NewToken( + &net.UDPAddr{IP: ip, Port: 1337}, + nil, + ) Expect(err).ToNot(HaveOccurred()) cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) @@ -41,6 +46,18 @@ var _ = Describe("Cookie Generator", func() { // the time resolution of the Cookie is just 1 second // if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second)) + Expect(cookie.OriginalDestConnectionID).To(BeNil()) + }) + + It("saves the connection ID", func() { + token, err := cookieGen.NewToken( + &net.UDPAddr{}, + protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ) + Expect(err).ToNot(HaveOccurred()) + cookie, err := cookieGen.DecodeToken(token) + Expect(err).ToNot(HaveOccurred()) + Expect(cookie.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) }) It("rejects invalid tokens", func() { @@ -56,7 +73,7 @@ var _ = Describe("Cookie Generator", func() { }) It("rejects tokens that can be decoded, but have additional payload", func() { - t, err := asn1.Marshal(token{Data: []byte("foobar")}) + t, err := asn1.Marshal(token{RemoteAddr: []byte("foobar")}) Expect(err).ToNot(HaveOccurred()) t = append(t, []byte("rest")...) enc, err := cookieGen.cookieProtector.NewToken(t) @@ -67,7 +84,7 @@ var _ = Describe("Cookie Generator", func() { // we don't generate tokens that have no data, but we should be able to handle them if we receive one for whatever reason It("doesn't panic if a tokens has no data", func() { - t, err := asn1.Marshal(token{Data: []byte("")}) + t, err := asn1.Marshal(token{RemoteAddr: []byte("")}) Expect(err).ToNot(HaveOccurred()) enc, err := cookieGen.cookieProtector.NewToken(t) Expect(err).ToNot(HaveOccurred()) @@ -86,7 +103,7 @@ var _ = Describe("Cookie Generator", func() { ip := net.ParseIP(addr) Expect(ip).ToNot(BeNil()) raddr := &net.UDPAddr{IP: ip, Port: 1337} - token, err := cookieGen.NewToken(raddr) + token, err := cookieGen.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) @@ -99,7 +116,7 @@ var _ = Describe("Cookie Generator", func() { It("uses the string representation an address that is not a UDP address", func() { raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - token, err := cookieGen.NewToken(raddr) + token, err := cookieGen.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) diff --git a/server.go b/server.go index 38b4a392..6a4f8b94 100644 --- a/server.go +++ b/server.go @@ -340,11 +340,16 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con return nil, nil, errors.New("dropping too small Initial packet") } - var cookie *handshake.Cookie + var cookie *Cookie + var origDestConnectionID protocol.ConnectionID if len(hdr.Token) > 0 { c, err := s.cookieGenerator.DecodeToken(hdr.Token) if err == nil { - cookie = c + cookie = &Cookie{ + RemoteAddr: c.RemoteAddr, + SentTime: c.SentTime, + } + origDestConnectionID = c.OriginalDestConnectionID } } if !s.config.AcceptCookie(p.remoteAddr, cookie) { @@ -359,7 +364,14 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con return nil, nil, err } s.logger.Debugf("Changing connection ID to %s.", connID) - sess, err := s.createNewSession(p.remoteAddr, hdr.DestConnectionID, hdr.SrcConnectionID, connID, hdr.Version) + sess, err := s.createNewSession( + p.remoteAddr, + origDestConnectionID, + hdr.DestConnectionID, + hdr.SrcConnectionID, + connID, + hdr.Version, + ) if err != nil { return nil, nil, err } @@ -369,7 +381,8 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con func (s *server) createNewSession( remoteAddr net.Addr, - origConnID protocol.ConnectionID, + origDestConnID protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, version protocol.VersionNumber, @@ -384,12 +397,13 @@ func (s *server) createNewSession( MaxUniStreams: uint64(s.config.MaxIncomingUniStreams), DisableMigration: true, // TODO(#855): generate a real token - StatelessResetToken: bytes.Repeat([]byte{42}, 16), + StatelessResetToken: bytes.Repeat([]byte{42}, 16), + OriginalConnectionID: origDestConnID, } sess, err := s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, s.sessionRunner, - origConnID, + clientDestConnID, destConnID, srcConnID, s.config, @@ -406,7 +420,7 @@ func (s *server) createNewSession( } func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { - token, err := s.cookieGenerator.NewToken(remoteAddr) + token, err := s.cookieGenerator.NewToken(remoteAddr, hdr.DestConnectionID) if err != nil { return err } diff --git a/server_test.go b/server_test.go index 7d57be38..d63ba735 100644 --- a/server_test.go +++ b/server_test.go @@ -165,7 +165,7 @@ var _ = Describe("Server", func() { close(done) return false } - token, err := serv.cookieGenerator.NewToken(raddr) + token, err := serv.cookieGenerator.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) serv.handlePacket(&receivedPacket{ remoteAddr: raddr, @@ -222,7 +222,7 @@ var _ = Describe("Server", func() { }) It("replies with a Retry packet, if a Cookie is required", func() { - serv.config.AcceptCookie = func(_ net.Addr, _ *handshake.Cookie) bool { return false } + serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return false } hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, @@ -244,7 +244,7 @@ var _ = Describe("Server", func() { }) It("creates a session, if no Cookie is required", func() { - serv.config.AcceptCookie = func(_ net.Addr, _ *handshake.Cookie) bool { return true } + serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true } hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, @@ -358,7 +358,7 @@ var _ = Describe("Server", func() { sess.EXPECT().run().Do(func() {}) return sess, nil } - _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, protocol.VersionWhatever) + _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Consistently(done).ShouldNot(BeClosed()) close(completeHandshake) diff --git a/session.go b/session.go index b04dd179..cea7cc9d 100644 --- a/session.go +++ b/session.go @@ -135,7 +135,7 @@ var _ streamSender = &session{} var newSession = func( conn connection, runner sessionRunner, - origConnID protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, @@ -163,7 +163,7 @@ var newSession = func( cs, err := handshake.NewCryptoSetupServer( initialStream, handshakeStream, - origConnID, + clientDestConnID, params, s.processTransportParameters, tlsConf,