diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 728cd837..204192f3 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -455,4 +455,30 @@ var _ = Describe("Handshake tests", func() { Eventually(done).Should(BeClosed()) }) }) + + It("rejects invalid Retry token with the INVALID_TOKEN error", func() { + tokenChan := make(chan *quic.Token, 10) + serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool { + tokenChan <- token + return false + } + + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer server.Close() + + _, err = quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + nil, + ) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("INVALID_TOKEN")) + Expect(tokenChan).To(HaveLen(2)) + token := <-tokenChan + Expect(token).To(BeNil()) + token = <-tokenChan + Expect(token).ToNot(BeNil()) + Expect(token.IsRetryToken).To(BeTrue()) + }) }) diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go index d9d669bf..1ce8b3af 100644 --- a/internal/qerr/error_codes.go +++ b/internal/qerr/error_codes.go @@ -22,6 +22,7 @@ const ( TransportParameterError ErrorCode = 0x8 ConnectionIDLimitError ErrorCode = 0x9 ProtocolViolation ErrorCode = 0xa + InvalidToken ErrorCode = 0xb CryptoBufferExceeded ErrorCode = 0xd ) @@ -69,6 +70,8 @@ func (e ErrorCode) String() string { return "CONNECTION_ID_LIMIT_ERROR" case ProtocolViolation: return "PROTOCOL_VIOLATION" + case InvalidToken: + return "INVALID_TOKEN" case CryptoBufferExceeded: return "CRYPTO_BUFFER_EXCEEDED" default: diff --git a/packet_unpacker.go b/packet_unpacker.go index 27eb52ec..ddf35408 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -102,7 +102,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte } func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { - extHdr, parseErr := u.unpack(opener, hdr, data) + extHdr, parseErr := u.unpackHeader(opener, hdr, data) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. @@ -126,7 +126,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket( rcvTime time.Time, data []byte, ) (*wire.ExtendedHeader, []byte, error) { - extHdr, parseErr := u.unpack(opener, hdr, data) + extHdr, parseErr := u.unpackHeader(opener, hdr, data) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. @@ -144,7 +144,20 @@ func (u *packetUnpacker) unpackShortHeaderPacket( return extHdr, decrypted, nil } -func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { +func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { + extHdr, err := unpackHeader(hd, hdr, data, u.version) + if err != nil && err != wire.ErrInvalidReservedBits { + return nil, err + } + extHdr.PacketNumber = protocol.DecodePacketNumber( + extHdr.PacketNumberLen, + u.largestRcvdPacketNumber, + extHdr.PacketNumber, + ) + return extHdr, err +} + +func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { r := bytes.NewReader(data) hdrLen := hdr.ParsedLen() @@ -163,7 +176,7 @@ func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byt data[hdrLen:hdrLen+4], ) // 3. parse the header (and learn the actual length of the packet number) - extHdr, parseErr := hdr.ParseExtended(r, u.version) + extHdr, parseErr := hdr.ParseExtended(r, version) if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return nil, parseErr } @@ -171,11 +184,5 @@ func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byt if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) } - - extHdr.PacketNumber = protocol.DecodePacketNumber( - extHdr.PacketNumberLen, - u.largestRcvdPacketNumber, - extHdr.PacketNumber, - ) return extHdr, parseErr } diff --git a/server.go b/server.go index 0831dd25..7dcd9f08 100644 --- a/server.go +++ b/server.go @@ -363,6 +363,12 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui } if !s.config.AcceptToken(p.remoteAddr, token) { go func() { + if token != nil && token.IsRetryToken { + if err := s.maybeSendInvalidToken(p, hdr); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) + } + return + } if err := s.sendRetry(p.remoteAddr, hdr); err != nil { s.logger.Debugf("Error sending Retry: %s", err) } @@ -512,13 +518,39 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { return err } +func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { + // Only send INVALID_TOKEN if we can unprotect the packet. + // This makes sure that we won't send it for packets that were corrupted. + sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer) + data := p.data[:hdr.ParsedLen()+hdr.Length] + extHdr, err := unpackHeader(opener, hdr, data, hdr.Version) + if err != nil { + // don't return the error here. Just drop the packet. + return nil + } + hdrLen := extHdr.ParsedLen() + if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { + // don't return the error here. Just drop the packet. + return nil + } + if s.logger.Debug() { + s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) + } + return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken) +} + func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer) + return s.sendError(remoteAddr, hdr, sealer, qerr.ServerBusy) +} + +// sendError sends the error as a response to the packet received with header hdr +func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.ErrorCode) error { packetBuffer := getPacketBuffer() defer packetBuffer.Release() buf := bytes.NewBuffer(packetBuffer.Data) - ccf := &wire.ConnectionCloseFrame{ErrorCode: qerr.ServerBusy} + ccf := &wire.ConnectionCloseFrame{ErrorCode: errorCode} replyHdr := &wire.ExtendedHeader{} replyHdr.IsLongHeader = true diff --git a/server_test.go b/server_test.go index d5fe1403..4826d3d1 100644 --- a/server_test.go +++ b/server_test.go @@ -14,6 +14,8 @@ import ( "sync/atomic" "time" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/qlog" "github.com/golang/mock/gomock" @@ -40,15 +42,27 @@ var _ = Describe("Server", func() { tlsConf *tls.Config ) - getPacket := func(hdr *wire.Header, data []byte) *receivedPacket { - buf := &bytes.Buffer{} + getPacket := func(hdr *wire.Header, p []byte) *receivedPacket { + buffer := getPacketBuffer() + buf := bytes.NewBuffer(buffer.Data) + if hdr.IsLongHeader { + hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 + } Expect((&wire.ExtendedHeader{ Header: *hdr, - PacketNumberLen: protocol.PacketNumberLen3, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen4, }).Write(buf, protocol.VersionTLS)).To(Succeed()) + n := buf.Len() + buf.Write(p) + data := buffer.Data[:buf.Len()] + sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient) + _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) + data = data[:len(data)+16] + sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) return &receivedPacket{ - data: append(buf.Bytes(), data...), - buffer: getPacketBuffer(), + data: data, + buffer: buffer, } } @@ -321,6 +335,61 @@ var _ = Describe("Server", func() { Expect(write.data[len(write.data)-16:]).To(Equal(handshake.GetRetryIntegrityTag(write.data[:len(write.data)-16], hdr.DestConnectionID)[:])) }) + It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + serv.handlePacket(packet) + var write mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&write)) + Expect(write.to.String()).To(Equal("127.0.0.1:1337")) + replyHdr := parseHeader(write.data) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + _, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient) + extHdr, err := unpackHeader(opener, replyHdr, write.data, hdr.Version) + Expect(err).ToNot(HaveOccurred()) + data, err := opener.Open(nil, write.data[extHdr.ParsedLen():], extHdr.PacketNumber, write.data[:extHdr.ParsedLen()]) + Expect(err).ToNot(HaveOccurred()) + f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := f.(*wire.ConnectionCloseFrame) + Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + }) + + It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + serv.handlePacket(packet) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) + It("creates a session, if no Token is required", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } hdr := &wire.Header{