upstream: sync to 0.39.1

This commit is contained in:
Gaukas Wang 2023-10-26 22:45:37 -06:00
commit 7c77243b04
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
147 changed files with 3740 additions and 2211 deletions

112
server.go
View file

@ -2,7 +2,6 @@ package quic
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net"
@ -60,7 +59,8 @@ type zeroRTTQueue struct {
type baseServer struct {
mutex sync.Mutex
acceptEarlyConns bool
disableVersionNegotiation bool
acceptEarlyConns bool
tlsConf *tls.Config
config *Config
@ -68,6 +68,7 @@ type baseServer struct {
conn rawConn
tokenGenerator *handshake.TokenGenerator
maxTokenAge time.Duration
connIDGenerator ConnectionIDGenerator
connHandler packetHandlerManager
@ -93,7 +94,7 @@ type baseServer struct {
*tls.Config,
*handshake.TokenGenerator,
bool, /* client address validated by an address validation token */
logging.ConnectionTracer,
*logging.ConnectionTracer,
uint64,
utils.Logger,
protocol.VersionNumber,
@ -109,7 +110,7 @@ type baseServer struct {
connQueue chan quicConn
connQueueLen int32 // to be used as an atomic
tracer logging.Tracer
tracer *logging.Tracer
logger utils.Logger
}
@ -225,32 +226,33 @@ func newServer(
connIDGenerator ConnectionIDGenerator,
tlsConf *tls.Config,
config *Config,
tracer logging.Tracer,
tracer *logging.Tracer,
onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
disableVersionNegotiation bool,
acceptEarly bool,
) (*baseServer, error) {
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
if err != nil {
return nil, err
}
) *baseServer {
s := &baseServer{
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan receivedPacket, 4),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
onClose: onClose,
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan receivedPacket, 4),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
@ -258,7 +260,7 @@ func newServer(
go s.run()
go s.runSendQueue()
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil
return s
}
func (s *baseServer) run() {
@ -351,7 +353,7 @@ func (s *baseServer) handlePacket(p receivedPacket) {
case s.receivedPackets <- p:
default:
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
}
}
@ -364,7 +366,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
if wire.IsVersionNegotiationPacket(p.data) {
s.logger.Debugf("Dropping Version Negotiation packet.")
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
@ -377,20 +379,20 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
// drop the packet if we failed to parse the protocol version
if err != nil {
s.logger.Debugf("Dropping a packet with an unknown version")
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
// send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, v) {
if s.config.DisableVersionNegotiationPackets {
if s.disableVersionNegotiation {
return false
}
if p.Size() < protocol.MinUnknownVersionPacketSize {
s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size())
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
@ -400,7 +402,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
if wire.Is0RTTPacket(p.data) {
if !s.acceptEarlyConns {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
@ -412,7 +414,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Error parsing packet: %s", err)
@ -420,7 +422,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
}
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
@ -431,7 +433,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
// There's little point in sending a Stateless Reset, since the client
// might not have received the token yet.
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
@ -450,7 +452,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
}
return false
@ -464,7 +466,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
@ -474,7 +476,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
}
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
@ -502,7 +504,7 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
continue
}
for _, p := range q.packets {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
p.buffer.Release()
@ -526,10 +528,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
if !token.ValidateRemoteAddr(addr) {
return false
}
if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge {
if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge {
return false
}
if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge {
if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() {
return false
}
return true
@ -538,7 +540,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release()
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
return errors.New("too short connection ID")
@ -623,7 +625,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
config = populateConfig(conf)
}
var tracer logging.ConnectionTracer
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID
@ -740,10 +742,10 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe
// append the Retry integrity tag
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
buf.Data = append(buf.Data, tag[:]...)
if s.tracer != nil {
if s.tracer != nil && s.tracer.SentPacket != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
}
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err
}
@ -761,7 +763,7 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Error parsing packet: %s", err)
@ -776,14 +778,14 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
if err != nil {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
}
return
}
hdrLen := extHdr.ParsedLen()
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
}
return
@ -839,10 +841,10 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
replyHdr.Log(s.logger)
wire.LogFrame(s.logger, ccf, true)
if s.tracer != nil {
if s.tracer != nil && s.tracer.SentPacket != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
}
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err
}
@ -868,7 +870,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
if err != nil { // should never happen
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
if s.tracer != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return
@ -877,10 +879,10 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
if s.tracer != nil {
if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil {
s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
}
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
}
}