Merge pull request #2524 from lucas-clemente/fix-buffer-use-after-release

fix buffer use after it was released when sending an INVALID_TOKEN error
This commit is contained in:
Marten Seemann 2020-05-05 19:14:39 +07:00 committed by GitHub
commit 2e402ffc86
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 24 deletions

View file

@ -296,7 +296,7 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
} }
} }
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet handled */ { func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* should the buffer be released */ {
// If we're creating a new session, the packet will be passed to the session. // If we're creating a new session, the packet will be passed to the session.
// The header will then be parsed again. // The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength)
@ -332,24 +332,18 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet
s.logger.Debugf("<- Received Initial packet.") s.logger.Debugf("<- Received Initial packet.")
sess, err := s.handleInitialImpl(p, hdr) if err := s.handleInitialImpl(p, hdr); err != nil {
if err != nil {
s.logger.Errorf("Error occurred handling initial packet: %s", err) s.logger.Errorf("Error occurred handling initial packet: %s", err)
return false
} }
// A retry was done, or the connection attempt was rejected, // Don't put the packet buffer back.
// or if the Initial was a duplicate. // handleInitialImpl deals with the buffer.
if sess == nil {
return false
}
// Don't put the packet buffer back if a new session was created.
// The session will handle the packet and take of that.
return true return true
} }
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, error) { func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
return nil, errors.New("too short connection ID") p.buffer.Release()
return errors.New("too short connection ID")
} }
var token *Token var token *Token
@ -367,6 +361,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
} }
if !s.config.AcceptToken(p.remoteAddr, token) { if !s.config.AcceptToken(p.remoteAddr, token) {
go func() { go func() {
defer p.buffer.Release()
if token != nil && token.IsRetryToken { if token != nil && token.IsRetryToken {
if err := s.maybeSendInvalidToken(p, hdr); err != nil { if err := s.maybeSendInvalidToken(p, hdr); err != nil {
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
@ -377,7 +372,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
s.logger.Debugf("Error sending Retry: %s", err) s.logger.Debugf("Error sending Retry: %s", err)
} }
}() }()
return nil, nil return nil
} }
if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize { if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
@ -387,12 +382,12 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
s.logger.Debugf("Error rejecting connection: %s", err) s.logger.Debugf("Error rejecting connection: %s", err)
} }
}() }()
return nil, nil return nil
} }
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
if err != nil { if err != nil {
return nil, err return err
} }
s.logger.Debugf("Changing connection ID to %s.", connID) s.logger.Debugf("Changing connection ID to %s.", connID)
sess := s.createNewSession( sess := s.createNewSession(
@ -404,7 +399,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
hdr.Version, hdr.Version,
) )
if sess == nil { if sess == nil {
return nil, nil p.buffer.Release()
return nil
} }
sess.handlePacket(p) sess.handlePacket(p)
for { for {
@ -414,7 +410,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
} }
sess.handlePacket(p) sess.handlePacket(p)
} }
return sess, nil return nil
} }
func (s *baseServer) createNewSession( func (s *baseServer) createNewSession(

View file

@ -14,18 +14,17 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/qlog"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qlog"
"github.com/lucas-clemente/quic-go/quictrace" "github.com/lucas-clemente/quic-go/quictrace"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -580,7 +579,7 @@ var _ = Describe("Server", func() {
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false) phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeFalse()) Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdSession).To(BeTrue()) Expect(createdSession).To(BeTrue())
}) })