mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
upstream: sync to 0.39.1
This commit is contained in:
commit
7c77243b04
147 changed files with 3740 additions and 2211 deletions
112
server.go
112
server.go
|
@ -2,7 +2,6 @@ package quic
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -60,7 +59,8 @@ type zeroRTTQueue struct {
|
|||
type baseServer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
acceptEarlyConns bool
|
||||
disableVersionNegotiation bool
|
||||
acceptEarlyConns bool
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
@ -68,6 +68,7 @@ type baseServer struct {
|
|||
conn rawConn
|
||||
|
||||
tokenGenerator *handshake.TokenGenerator
|
||||
maxTokenAge time.Duration
|
||||
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
connHandler packetHandlerManager
|
||||
|
@ -93,7 +94,7 @@ type baseServer struct {
|
|||
*tls.Config,
|
||||
*handshake.TokenGenerator,
|
||||
bool, /* client address validated by an address validation token */
|
||||
logging.ConnectionTracer,
|
||||
*logging.ConnectionTracer,
|
||||
uint64,
|
||||
utils.Logger,
|
||||
protocol.VersionNumber,
|
||||
|
@ -109,7 +110,7 @@ type baseServer struct {
|
|||
connQueue chan quicConn
|
||||
connQueueLen int32 // to be used as an atomic
|
||||
|
||||
tracer logging.Tracer
|
||||
tracer *logging.Tracer
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -225,32 +226,33 @@ func newServer(
|
|||
connIDGenerator ConnectionIDGenerator,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
tracer logging.Tracer,
|
||||
tracer *logging.Tracer,
|
||||
onClose func(),
|
||||
tokenGeneratorKey TokenGeneratorKey,
|
||||
maxTokenAge time.Duration,
|
||||
disableVersionNegotiation bool,
|
||||
acceptEarly bool,
|
||||
) (*baseServer, error) {
|
||||
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
) *baseServer {
|
||||
s := &baseServer{
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
tokenGenerator: tokenGenerator,
|
||||
connIDGenerator: connIDGenerator,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn),
|
||||
errorChan: make(chan struct{}),
|
||||
running: make(chan struct{}),
|
||||
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||
versionNegotiationQueue: make(chan receivedPacket, 4),
|
||||
invalidTokenQueue: make(chan receivedPacket, 4),
|
||||
newConn: newConnection,
|
||||
tracer: tracer,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
acceptEarlyConns: acceptEarly,
|
||||
onClose: onClose,
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
||||
maxTokenAge: maxTokenAge,
|
||||
connIDGenerator: connIDGenerator,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn),
|
||||
errorChan: make(chan struct{}),
|
||||
running: make(chan struct{}),
|
||||
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||
versionNegotiationQueue: make(chan receivedPacket, 4),
|
||||
invalidTokenQueue: make(chan receivedPacket, 4),
|
||||
newConn: newConnection,
|
||||
tracer: tracer,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
acceptEarlyConns: acceptEarly,
|
||||
disableVersionNegotiation: disableVersionNegotiation,
|
||||
onClose: onClose,
|
||||
}
|
||||
if acceptEarly {
|
||||
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
||||
|
@ -258,7 +260,7 @@ func newServer(
|
|||
go s.run()
|
||||
go s.runSendQueue()
|
||||
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
||||
return s, nil
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *baseServer) run() {
|
||||
|
@ -351,7 +353,7 @@ func (s *baseServer) handlePacket(p receivedPacket) {
|
|||
case s.receivedPackets <- p:
|
||||
default:
|
||||
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
}
|
||||
|
@ -364,7 +366,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
|
|||
|
||||
if wire.IsVersionNegotiationPacket(p.data) {
|
||||
s.logger.Debugf("Dropping Version Negotiation packet.")
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
|
@ -377,20 +379,20 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
|
|||
// drop the packet if we failed to parse the protocol version
|
||||
if err != nil {
|
||||
s.logger.Debugf("Dropping a packet with an unknown version")
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
||||
if !protocol.IsSupportedVersion(s.config.Versions, v) {
|
||||
if s.config.DisableVersionNegotiationPackets {
|
||||
if s.disableVersionNegotiation {
|
||||
return false
|
||||
}
|
||||
|
||||
if p.Size() < protocol.MinUnknownVersionPacketSize {
|
||||
s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size())
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
|
@ -400,7 +402,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
|
|||
|
||||
if wire.Is0RTTPacket(p.data) {
|
||||
if !s.acceptEarlyConns {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
|
@ -412,7 +414,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
|
|||
// The header will then be parsed again.
|
||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||
if err != nil {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
s.logger.Debugf("Error parsing packet: %s", err)
|
||||
|
@ -420,7 +422,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
|
|||
}
|
||||
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
|
||||
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
|
@ -431,7 +433,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
|
|||
// There's little point in sending a Stateless Reset, since the client
|
||||
// might not have received the token yet.
|
||||
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
|
@ -450,7 +452,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
|
|||
func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
if err != nil {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
return false
|
||||
|
@ -464,7 +466,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
|
|||
|
||||
if q, ok := s.zeroRTTQueues[connID]; ok {
|
||||
if len(q.packets) >= protocol.Max0RTTQueueLen {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return false
|
||||
|
@ -474,7 +476,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
|
|||
}
|
||||
|
||||
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return false
|
||||
|
@ -502,7 +504,7 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
|
|||
continue
|
||||
}
|
||||
for _, p := range q.packets {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
p.buffer.Release()
|
||||
|
@ -526,10 +528,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
|
|||
if !token.ValidateRemoteAddr(addr) {
|
||||
return false
|
||||
}
|
||||
if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge {
|
||||
if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge {
|
||||
return false
|
||||
}
|
||||
if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge {
|
||||
if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
@ -538,7 +540,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
|
|||
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
|
||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
p.buffer.Release()
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return errors.New("too short connection ID")
|
||||
|
@ -623,7 +625,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
|||
}
|
||||
config = populateConfig(conf)
|
||||
}
|
||||
var tracer logging.ConnectionTracer
|
||||
var tracer *logging.ConnectionTracer
|
||||
if config.Tracer != nil {
|
||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||
connID := hdr.DestConnectionID
|
||||
|
@ -740,10 +742,10 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe
|
|||
// append the Retry integrity tag
|
||||
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
|
||||
buf.Data = append(buf.Data, tag[:]...)
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.SentPacket != nil {
|
||||
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
|
||||
}
|
||||
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
|
||||
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -761,7 +763,7 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
|
|||
|
||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||
if err != nil {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
s.logger.Debugf("Error parsing packet: %s", err)
|
||||
|
@ -776,14 +778,14 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
|
|||
// Only send INVALID_TOKEN if we can unprotect the packet.
|
||||
// This makes sure that we won't send it for packets that were corrupted.
|
||||
if err != nil {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
return
|
||||
}
|
||||
hdrLen := extHdr.ParsedLen()
|
||||
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
|
||||
}
|
||||
return
|
||||
|
@ -839,10 +841,10 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
|
|||
|
||||
replyHdr.Log(s.logger)
|
||||
wire.LogFrame(s.logger, ccf, true)
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.SentPacket != nil {
|
||||
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
|
||||
}
|
||||
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
|
||||
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -868,7 +870,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
|
|||
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
|
||||
if err != nil { // should never happen
|
||||
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return
|
||||
|
@ -877,10 +879,10 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
|
|||
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
|
||||
|
||||
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
|
||||
if s.tracer != nil {
|
||||
if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil {
|
||||
s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
|
||||
}
|
||||
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
|
||||
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
|
||||
s.logger.Debugf("Error sending Version Negotiation: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue