diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index 7c36e47b..af8c3f1e 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -21,6 +21,7 @@ type Token struct { SentTime time.Time // only set for retry tokens OriginalDestConnectionID protocol.ConnectionID + RetrySrcConnectionID protocol.ConnectionID } // token is the struct that is used for ASN1 serialization and deserialization @@ -29,6 +30,7 @@ type token struct { RemoteAddr []byte Timestamp int64 OriginalDestConnectionID []byte + RetrySrcConnectionID []byte } // A TokenGenerator generates tokens @@ -48,11 +50,16 @@ func NewTokenGenerator() (*TokenGenerator, error) { } // NewRetryToken generates a new token for a Retry for a given source address -func (g *TokenGenerator) NewRetryToken(raddr net.Addr, origConnID protocol.ConnectionID) ([]byte, error) { +func (g *TokenGenerator) NewRetryToken( + raddr net.Addr, + origDestConnID protocol.ConnectionID, + retrySrcConnID protocol.ConnectionID, +) ([]byte, error) { data, err := asn1.Marshal(token{ IsRetryToken: true, RemoteAddr: encodeRemoteAddr(raddr), - OriginalDestConnectionID: origConnID, + OriginalDestConnectionID: origDestConnID, + RetrySrcConnectionID: retrySrcConnID, Timestamp: time.Now().UnixNano(), }) if err != nil { @@ -97,8 +104,9 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { RemoteAddr: decodeRemoteAddr(t.RemoteAddr), SentTime: time.Unix(0, t.Timestamp), } - if len(t.OriginalDestConnectionID) > 0 { + if t.IsRetryToken { token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) + token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) } return token, nil } diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 270d7ac5..91cba354 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -22,7 +22,7 @@ var _ = Describe("Token Generator", func() { It("generates a token", func() { ip := net.IPv4(127, 0, 0, 1) - token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, nil) + token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(token).ToNot(BeEmpty()) }) @@ -38,24 +38,28 @@ var _ = Describe("Token Generator", func() { tokenEnc, err := tokenGen.NewRetryToken( &net.UDPAddr{IP: ip, Port: 1337}, nil, + nil, ) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) Expect(token.RemoteAddr).To(Equal("192.168.0.1")) Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - Expect(token.OriginalDestConnectionID).To(BeNil()) + Expect(token.OriginalDestConnectionID.Len()).To(BeZero()) + Expect(token.RetrySrcConnectionID.Len()).To(BeZero()) }) It("saves the connection ID", func() { tokenEnc, err := tokenGen.NewRetryToken( &net.UDPAddr{}, protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, ) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) Expect(token.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(token.RetrySrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) }) It("rejects invalid tokens", func() { @@ -101,7 +105,7 @@ var _ = Describe("Token Generator", func() { ip := net.ParseIP(addr) Expect(ip).ToNot(BeNil()) raddr := &net.UDPAddr{IP: ip, Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil) + tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) @@ -112,7 +116,7 @@ var _ = Describe("Token Generator", func() { It("uses the string representation an address that is not a UDP address", func() { raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil) + tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) diff --git a/server.go b/server.go index 5011150b..cab88dc2 100644 --- a/server.go +++ b/server.go @@ -494,11 +494,11 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the session. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID) + srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { return err } - connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID, srcConnID) if err != nil { return err } @@ -506,10 +506,10 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { replyHdr.IsLongHeader = true replyHdr.Type = protocol.PacketTypeRetry replyHdr.Version = hdr.Version - replyHdr.SrcConnectionID = connID + replyHdr.SrcConnectionID = srcConnID replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.Token = token - s.logger.Debugf("Changing connection ID to %s.", connID) + s.logger.Debugf("Changing connection ID to %s.", srcConnID) s.logger.Debugf("-> Sending Retry") replyHdr.Log(s.logger) buf := &bytes.Buffer{} diff --git a/server_test.go b/server_test.go index 003ef1ba..b466078b 100644 --- a/server_test.go +++ b/server_test.go @@ -253,7 +253,7 @@ var _ = Describe("Server", func() { close(done) return false } - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil) + token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) Expect(err).ToNot(HaveOccurred()) packet := getPacket(&wire.Header{ IsLongHeader: true, @@ -291,7 +291,7 @@ var _ = Describe("Server", func() { It("creates a session when the token is accepted", func() { serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } - retryToken, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) + retryToken, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, nil) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, @@ -409,7 +409,7 @@ var _ = Describe("Server", func() { It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil) + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, @@ -445,7 +445,7 @@ var _ = Describe("Server", func() { It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil) + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true,