send an INVALID_TOKEN error when receiving an invalid token

This commit is contained in:
Marten Seemann 2020-02-24 11:18:36 +07:00
parent 3e083d19f4
commit e57caf0bae
5 changed files with 153 additions and 16 deletions

View file

@ -455,4 +455,30 @@ var _ = Describe("Handshake tests", func() {
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
}) })
It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
tokenChan := make(chan *quic.Token, 10)
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
tokenChan <- token
return false
}
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
_, err = quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
nil,
)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("INVALID_TOKEN"))
Expect(tokenChan).To(HaveLen(2))
token := <-tokenChan
Expect(token).To(BeNil())
token = <-tokenChan
Expect(token).ToNot(BeNil())
Expect(token.IsRetryToken).To(BeTrue())
})
}) })

View file

@ -22,6 +22,7 @@ const (
TransportParameterError ErrorCode = 0x8 TransportParameterError ErrorCode = 0x8
ConnectionIDLimitError ErrorCode = 0x9 ConnectionIDLimitError ErrorCode = 0x9
ProtocolViolation ErrorCode = 0xa ProtocolViolation ErrorCode = 0xa
InvalidToken ErrorCode = 0xb
CryptoBufferExceeded ErrorCode = 0xd CryptoBufferExceeded ErrorCode = 0xd
) )
@ -69,6 +70,8 @@ func (e ErrorCode) String() string {
return "CONNECTION_ID_LIMIT_ERROR" return "CONNECTION_ID_LIMIT_ERROR"
case ProtocolViolation: case ProtocolViolation:
return "PROTOCOL_VIOLATION" return "PROTOCOL_VIOLATION"
case InvalidToken:
return "INVALID_TOKEN"
case CryptoBufferExceeded: case CryptoBufferExceeded:
return "CRYPTO_BUFFER_EXCEEDED" return "CRYPTO_BUFFER_EXCEEDED"
default: default:

View file

@ -102,7 +102,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte
} }
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpack(opener, hdr, data) extHdr, parseErr := u.unpackHeader(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking. // If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker // This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption. // to gain information about the header encryption.
@ -126,7 +126,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket(
rcvTime time.Time, rcvTime time.Time,
data []byte, data []byte,
) (*wire.ExtendedHeader, []byte, error) { ) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpack(opener, hdr, data) extHdr, parseErr := u.unpackHeader(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking. // If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker // This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption. // to gain information about the header encryption.
@ -144,7 +144,20 @@ func (u *packetUnpacker) unpackShortHeaderPacket(
return extHdr, decrypted, nil return extHdr, decrypted, nil
} }
func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
extHdr, err := unpackHeader(hd, hdr, data, u.version)
if err != nil && err != wire.ErrInvalidReservedBits {
return nil, err
}
extHdr.PacketNumber = protocol.DecodePacketNumber(
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
extHdr.PacketNumber,
)
return extHdr, err
}
func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data) r := bytes.NewReader(data)
hdrLen := hdr.ParsedLen() hdrLen := hdr.ParsedLen()
@ -163,7 +176,7 @@ func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byt
data[hdrLen:hdrLen+4], data[hdrLen:hdrLen+4],
) )
// 3. parse the header (and learn the actual length of the packet number) // 3. parse the header (and learn the actual length of the packet number)
extHdr, parseErr := hdr.ParseExtended(r, u.version) extHdr, parseErr := hdr.ParseExtended(r, version)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr return nil, parseErr
} }
@ -171,11 +184,5 @@ func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byt
if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
} }
extHdr.PacketNumber = protocol.DecodePacketNumber(
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
extHdr.PacketNumber,
)
return extHdr, parseErr return extHdr, parseErr
} }

View file

@ -363,6 +363,12 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
} }
if !s.config.AcceptToken(p.remoteAddr, token) { if !s.config.AcceptToken(p.remoteAddr, token) {
go func() { go func() {
if token != nil && token.IsRetryToken {
if err := s.maybeSendInvalidToken(p, hdr); err != nil {
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
}
return
}
if err := s.sendRetry(p.remoteAddr, hdr); err != nil { if err := s.sendRetry(p.remoteAddr, hdr); err != nil {
s.logger.Debugf("Error sending Retry: %s", err) s.logger.Debugf("Error sending Retry: %s", err)
} }
@ -512,13 +518,39 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
return err return err
} }
func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error {
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer)
data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackHeader(opener, hdr, data, hdr.Version)
if err != nil {
// don't return the error here. Just drop the packet.
return nil
}
hdrLen := extHdr.ParsedLen()
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
// don't return the error here. Just drop the packet.
return nil
}
if s.logger.Debug() {
s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr)
}
return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken)
}
func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error {
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer) sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer)
return s.sendError(remoteAddr, hdr, sealer, qerr.ServerBusy)
}
// sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.ErrorCode) error {
packetBuffer := getPacketBuffer() packetBuffer := getPacketBuffer()
defer packetBuffer.Release() defer packetBuffer.Release()
buf := bytes.NewBuffer(packetBuffer.Data) buf := bytes.NewBuffer(packetBuffer.Data)
ccf := &wire.ConnectionCloseFrame{ErrorCode: qerr.ServerBusy} ccf := &wire.ConnectionCloseFrame{ErrorCode: errorCode}
replyHdr := &wire.ExtendedHeader{} replyHdr := &wire.ExtendedHeader{}
replyHdr.IsLongHeader = true replyHdr.IsLongHeader = true

View file

@ -14,6 +14,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/qlog" "github.com/lucas-clemente/quic-go/qlog"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -40,15 +42,27 @@ var _ = Describe("Server", func() {
tlsConf *tls.Config tlsConf *tls.Config
) )
getPacket := func(hdr *wire.Header, data []byte) *receivedPacket { getPacket := func(hdr *wire.Header, p []byte) *receivedPacket {
buf := &bytes.Buffer{} buffer := getPacketBuffer()
buf := bytes.NewBuffer(buffer.Data)
if hdr.IsLongHeader {
hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
}
Expect((&wire.ExtendedHeader{ Expect((&wire.ExtendedHeader{
Header: *hdr, Header: *hdr,
PacketNumberLen: protocol.PacketNumberLen3, PacketNumber: 0x42,
PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.VersionTLS)).To(Succeed()) }).Write(buf, protocol.VersionTLS)).To(Succeed())
n := buf.Len()
buf.Write(p)
data := buffer.Data[:buf.Len()]
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient)
_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
data = data[:len(data)+16]
sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
return &receivedPacket{ return &receivedPacket{
data: append(buf.Bytes(), data...), data: data,
buffer: getPacketBuffer(), buffer: buffer,
} }
} }
@ -321,6 +335,61 @@ var _ = Describe("Server", func() {
Expect(write.data[len(write.data)-16:]).To(Equal(handshake.GetRetryIntegrityTag(write.data[:len(write.data)-16], hdr.DestConnectionID)[:])) Expect(write.data[len(write.data)-16:]).To(Equal(handshake.GetRetryIntegrityTag(write.data[:len(write.data)-16], hdr.DestConnectionID)[:]))
}) })
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil)
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
serv.handlePacket(packet)
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
replyHdr := parseHeader(write.data)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
_, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient)
extHdr, err := unpackHeader(opener, replyHdr, write.data, hdr.Version)
Expect(err).ToNot(HaveOccurred())
data, err := opener.Open(nil, write.data[extHdr.ParsedLen():], extHdr.PacketNumber, write.data[:extHdr.ParsedLen()])
Expect(err).ToNot(HaveOccurred())
f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
ccf := f.(*wire.ConnectionCloseFrame)
Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken))
Expect(ccf.ReasonPhrase).To(BeEmpty())
})
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil)
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Token: token,
Version: protocol.VersionTLS,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
serv.handlePacket(packet)
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("creates a session, if no Token is required", func() { It("creates a session, if no Token is required", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
hdr := &wire.Header{ hdr := &wire.Header{