Big update to Update quic-go

This commit is contained in:
Frank Denis 2023-07-22 00:41:27 +02:00
parent a4eda39563
commit d659a801c2
92 changed files with 2442 additions and 13388 deletions

View file

@ -1,4 +1,6 @@
run:
skip-files:
- internal/handshake/cipher_suite.go
linters-settings:
depguard:
type: blacklist

View file

@ -12,10 +12,6 @@ In addition to these base RFCs, it also implements the following RFCs:
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem.
This repository provides both a QUIC implementation, located in the `quic` package, as well as an HTTP/3 implementation, located in the `http3` package.
## Using QUIC
### Running a Server
@ -136,7 +132,7 @@ The `quic.Transport` contains a few configuration options that don't apply to an
#### When the remote Peer closes the Connection
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Users can use errors assertions to find out what exactly went wrong:
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Additionally, it is set as cancellation cause of the connection context. Users can use errors assertions to find out what exactly went wrong:
* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
@ -224,7 +220,8 @@ quic-go always aims to support the latest two Go releases.
### Dependency on forked crypto/tls
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20) and [qtls for Go 1.19](https://github.com/quic-go/qtls-go1-19). This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20).
This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
## Contributing

View file

@ -8,6 +8,7 @@ coverage:
- http3/gzip_reader.go
- interop/
- internal/ackhandler/packet_linkedlist.go
- internal/handshake/cipher_suite.go
- internal/utils/byteinterval_linkedlist.go
- internal/utils/newconnectionid_linkedlist.go
- internal/utils/packetinterval_linkedlist.go

View file

@ -52,11 +52,13 @@ type streamManager interface {
}
type cryptoStreamHandler interface {
RunHandshake()
StartHandshake() error
ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
GetSessionTicket() ([]byte, error)
NextEvent() handshake.Event
DiscardInitialKeys()
io.Closer
ConnectionState() handshake.ConnectionState
}
@ -96,18 +98,6 @@ type connRunner interface {
RemoveResetToken(protocol.StatelessResetToken)
}
type handshakeRunner struct {
onReceivedParams func(*wire.TransportParameters)
onError func(error)
dropKeys func(protocol.EncryptionLevel)
onHandshakeComplete func()
}
func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) }
func (r *handshakeRunner) OnError(e error) { r.onError(e) }
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() }
type closeError struct {
err error
remote bool
@ -165,6 +155,8 @@ type connection struct {
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
initialStream cryptoStream
handshakeStream cryptoStream
oneRTTStream cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler
@ -176,19 +168,17 @@ type connection struct {
closeChan chan closeError
ctx context.Context
ctxCancel context.CancelFunc
ctxCancel context.CancelCauseFunc
handshakeCtx context.Context
handshakeCtxCancel context.CancelFunc
undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []receivedPacket
clientHelloWritten <-chan *wire.TransportParameters
earlyConnReadyChan chan struct{}
handshakeCompleteChan chan struct{} // is closed when the handshake completes
sentFirstPacket bool
handshakeComplete bool
handshakeConfirmed bool
earlyConnReadyChan chan struct{}
sentFirstPacket bool
handshakeComplete bool
handshakeConfirmed bool
receivedRetry bool
versionNegotiated bool
@ -248,17 +238,16 @@ var newConnection = func(
v protocol.VersionNumber,
) quicConn {
s := &connection{
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(),
perspective: protocol.PerspectiveServer,
handshakeCompleteChan: make(chan struct{}),
tracer: tracer,
logger: logger,
version: v,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(),
perspective: protocol.PerspectiveServer,
tracer: tracer,
logger: logger,
version: v,
}
if origDestConnID.Len() > 0 {
s.logID = origDestConnID.String()
@ -283,7 +272,7 @@ var newConnection = func(
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
getMaxPacketSize(s.conn.RemoteAddr()),
@ -294,8 +283,6 @@ var newConnection = func(
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -327,21 +314,8 @@ var newConnection = func(
s.tracer.SentTransportParameters(params)
}
cs := handshake.NewCryptoSetupServer(
initialStream,
handshakeStream,
clientDestConnID,
conn.LocalAddr(),
conn.RemoteAddr(),
params,
&handshakeRunner{
onReceivedParams: s.handleTransportParameters,
onError: s.closeLocal,
dropKeys: s.dropEncryptionLevel,
onHandshakeComplete: func() {
runner.Retire(clientDestConnID)
close(s.handshakeCompleteChan)
},
},
tlsConf,
conf.Allow0RTT,
s.rttStats,
@ -350,9 +324,9 @@ var newConnection = func(
s.version,
)
s.cryptoStreamHandler = cs
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, s.oneRTTStream)
return s
}
@ -374,18 +348,17 @@ var newClientConnection = func(
v protocol.VersionNumber,
) quicConn {
s := &connection{
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
handshakeCompleteChan: make(chan struct{}),
logID: destConnID.String(),
logger: logger,
tracer: tracer,
versionNegotiated: hasNegotiatedVersion,
version: v,
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
logID: destConnID.String(),
logger: logger,
tracer: tracer,
versionNegotiated: hasNegotiatedVersion,
version: v,
}
s.connIDManager = newConnIDManager(
destConnID,
@ -405,7 +378,7 @@ var newClientConnection = func(
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
@ -416,8 +389,7 @@ var newClientConnection = func(
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -445,19 +417,9 @@ var newClientConnection = func(
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
initialStream,
handshakeStream,
cs := handshake.NewCryptoSetupClient(
destConnID,
conn.LocalAddr(),
conn.RemoteAddr(),
params,
&handshakeRunner{
onReceivedParams: s.handleTransportParameters,
onError: s.closeLocal,
dropKeys: s.dropEncryptionLevel,
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
},
tlsConf,
enable0RTT,
s.rttStats,
@ -465,11 +427,10 @@ var newClientConnection = func(
logger,
s.version,
)
s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
if len(tlsConf.ServerName) > 0 {
s.tokenStoreKey = tlsConf.ServerName
} else {
@ -484,6 +445,8 @@ var newClientConnection = func(
}
func (s *connection) preSetup() {
s.initialStream = newCryptoStream()
s.handshakeStream = newCryptoStream()
s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue()
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams)
@ -526,15 +489,19 @@ func (s *connection) preSetup() {
// run the connection main loop
func (s *connection) run() error {
defer s.ctxCancel()
var closeErr closeError
defer func() {
s.ctxCancel(closeErr.err)
}()
s.timer = *newTimer()
handshaking := make(chan struct{})
go func() {
defer close(handshaking)
s.cryptoStreamHandler.RunHandshake()
}()
if err := s.cryptoStreamHandler.StartHandshake(); err != nil {
return err
}
if err := s.handleHandshakeEvents(); err != nil {
return err
}
go func() {
if err := s.sendQueue.Run(); err != nil {
s.destroyImpl(err)
@ -542,23 +509,10 @@ func (s *connection) run() error {
}()
if s.perspective == protocol.PerspectiveClient {
select {
case zeroRTTParams := <-s.clientHelloWritten:
s.scheduleSending()
if zeroRTTParams != nil {
s.restoreTransportParameters(zeroRTTParams)
close(s.earlyConnReadyChan)
}
case closeErr := <-s.closeChan:
// put the close error back into the channel, so that the run loop can receive it
s.closeChan <- closeErr
}
s.scheduleSending() // so the ClientHello actually gets sent
}
var (
closeErr closeError
sendQueueAvailable <-chan struct{}
)
var sendQueueAvailable <-chan struct{}
runLoop:
for {
@ -566,8 +520,6 @@ runLoop:
select {
case closeErr = <-s.closeChan:
break runLoop
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
default:
}
@ -638,8 +590,6 @@ runLoop:
if !wasProcessed {
continue
}
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
}
}
@ -686,7 +636,6 @@ runLoop:
}
s.cryptoStreamHandler.Close()
<-handshaking
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
s.handleCloseError(&closeErr)
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil {
@ -717,7 +666,9 @@ func (s *connection) supportsDatagrams() bool {
func (s *connection) ConnectionState() ConnectionState {
s.connStateMutex.Lock()
defer s.connStateMutex.Unlock()
s.connState.TLS = s.cryptoStreamHandler.ConnectionState()
cs := s.cryptoStreamHandler.ConnectionState()
s.connState.TLS = cs.ConnectionState
s.connState.Used0RTT = cs.Used0RTT
return s.connState
}
@ -764,9 +715,8 @@ func (s *connection) idleTimeoutStartTime() time.Time {
return utils.MaxTime(s.lastPacketReceivedTime, s.firstAckElicitingPacketAfterIdleSentTime)
}
func (s *connection) handleHandshakeComplete() {
func (s *connection) handleHandshakeComplete() error {
s.handshakeComplete = true
s.handshakeCompleteChan = nil // prevent this case from ever being selected again
defer s.handshakeCtxCancel()
// Once the handshake completes, we have derived 1-RTT keys.
// There's no point in queueing undecryptable packets for later decryption any more.
@ -777,16 +727,18 @@ func (s *connection) handleHandshakeComplete() {
if s.perspective == protocol.PerspectiveClient {
s.applyTransportParameters()
return
return nil
}
s.handleHandshakeConfirmed()
if err := s.handleHandshakeConfirmed(); err != nil {
return err
}
ticket, err := s.cryptoStreamHandler.GetSessionTicket()
if err != nil {
s.closeLocal(err)
return err
}
if ticket != nil {
if ticket != nil { // may be nil if session tickets are disabled via tls.Config.SessionTicketsDisabled
s.oneRTTStream.Write(ticket)
for s.oneRTTStream.HasData() {
s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
@ -794,13 +746,18 @@ func (s *connection) handleHandshakeComplete() {
}
token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr())
if err != nil {
s.closeLocal(err)
return err
}
s.queueControlFrame(&wire.NewTokenFrame{Token: token})
s.queueControlFrame(&wire.HandshakeDoneFrame{})
return nil
}
func (s *connection) handleHandshakeConfirmed() {
func (s *connection) handleHandshakeConfirmed() error {
if err := s.dropEncryptionLevel(protocol.EncryptionHandshake); err != nil {
return err
}
s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
s.cryptoStreamHandler.SetHandshakeConfirmed()
@ -812,6 +769,7 @@ func (s *connection) handleHandshakeConfirmed() {
}
s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize))
}
return nil
}
func (s *connection) handlePacketImpl(rp receivedPacket) bool {
@ -1213,6 +1171,14 @@ func (s *connection) handleUnpackedLongHeaderPacket(
}
}
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake {
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
// See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
return err
}
}
s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false
@ -1378,16 +1344,41 @@ func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame
}
func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
if err != nil {
if err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil {
return err
}
if encLevelChanged {
// Queue all packets for decryption that have been undecryptable so far.
s.undecryptablePacketsToProcess = s.undecryptablePackets
s.undecryptablePackets = nil
return s.handleHandshakeEvents()
}
func (s *connection) handleHandshakeEvents() error {
for {
ev := s.cryptoStreamHandler.NextEvent()
var err error
switch ev.Kind {
case handshake.EventNoEvent:
return nil
case handshake.EventHandshakeComplete:
err = s.handleHandshakeComplete()
case handshake.EventReceivedTransportParameters:
err = s.handleTransportParameters(ev.TransportParameters)
case handshake.EventRestoredTransportParameters:
s.restoreTransportParameters(ev.TransportParameters)
close(s.earlyConnReadyChan)
case handshake.EventReceivedReadKeys:
// Queue all packets for decryption that have been undecryptable so far.
s.undecryptablePacketsToProcess = s.undecryptablePackets
s.undecryptablePackets = nil
case handshake.EventDiscard0RTTKeys:
err = s.dropEncryptionLevel(protocol.Encryption0RTT)
case handshake.EventWriteInitialData:
_, err = s.initialStream.Write(ev.Data)
case handshake.EventWriteHandshakeData:
_, err = s.handshakeStream.Write(ev.Data)
}
if err != nil {
return err
}
}
return nil
}
func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error {
@ -1496,7 +1487,9 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr
return nil
}
if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed {
s.handleHandshakeConfirmed()
if err := s.handleHandshakeConfirmed(); err != nil {
return err
}
}
return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked())
}
@ -1628,21 +1621,24 @@ func (s *connection) handleCloseError(closeErr *closeError) {
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
}
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel)
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error {
if s.tracer != nil {
s.tracer.DroppedEncryptionLevel(encLevel)
}
if encLevel == protocol.Encryption0RTT {
s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel)
//nolint:exhaustive // only Initial and 0-RTT need special treatment
switch encLevel {
case protocol.EncryptionInitial:
s.cryptoStreamHandler.DiscardInitialKeys()
case protocol.Encryption0RTT:
s.streamsMap.ResetFor0RTT()
if err := s.connFlowController.Reset(); err != nil {
s.closeLocal(err)
}
if err := s.framer.Handle0RTTRejection(); err != nil {
s.closeLocal(err)
return err
}
return s.framer.Handle0RTTRejection()
}
return s.cryptoStreamManager.Drop(encLevel)
}
// is called for the client, when restoring transport parameters saved for 0-RTT
@ -1660,13 +1656,12 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters
s.connStateMutex.Unlock()
}
func (s *connection) handleTransportParameters(params *wire.TransportParameters) {
func (s *connection) handleTransportParameters(params *wire.TransportParameters) error {
if err := s.checkTransportParameters(params); err != nil {
s.closeLocal(&qerr.TransportError{
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
})
return
}
}
s.peerParams = params
// On the client side we have to wait for handshake completion.
@ -1681,6 +1676,7 @@ func (s *connection) handleTransportParameters(params *wire.TransportParameters)
s.connStateMutex.Lock()
s.connState.SupportsDatagrams = s.supportsDatagrams()
s.connStateMutex.Unlock()
return nil
}
func (s *connection) checkTransportParameters(params *wire.TransportParameters) error {
@ -1817,6 +1813,9 @@ func (s *connection) sendPackets(now time.Time) error {
s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset})
}
s.windowUpdateQueue.QueueAll()
if cf := s.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil {
s.queueControlFrame(cf)
}
if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
@ -1824,7 +1823,9 @@ func (s *connection) sendPackets(now time.Time) error {
return err
}
s.sentFirstPacket = true
s.sendPackedCoalescedPacket(packet, now)
if err := s.sendPackedCoalescedPacket(packet, now); err != nil {
return err
}
sendMode := s.sentPacketHandler.SendMode(now)
if sendMode == ackhandler.SendPacingLimited {
s.resetPacingDeadline()
@ -1944,8 +1945,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if packet == nil {
return nil
}
s.sendPackedCoalescedPacket(packet, time.Now())
return nil
return s.sendPackedCoalescedPacket(packet, time.Now())
}
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
@ -1989,8 +1989,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
}
s.sendPackedCoalescedPacket(packet, now)
return nil
return s.sendPackedCoalescedPacket(packet, now)
}
// appendPacket appends a new packet to the given packetBuffer.
@ -2020,7 +2019,7 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti
s.connIDManager.SentPacket()
}
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) {
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error {
s.logCoalescedPacket(packet)
for _, p := range packet.longHdrPackets {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
@ -2031,6 +2030,13 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
largestAcked = p.ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false)
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake {
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
// See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
return err
}
}
}
if p := packet.shortHdrPacket; p != nil {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
@ -2044,6 +2050,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
}
s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer, packet.buffer.Len())
return nil
}
func (s *connection) sendConnectionClose(e error) ([]byte, error) {
@ -2078,6 +2085,9 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
for _, frame := range p.frames {
wire.LogFrame(s.logger, frame.Frame, true)
}
for _, frame := range p.streamFrames {
wire.LogFrame(s.logger, frame.Frame, true)
}
}
// tracing
@ -2086,6 +2096,9 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
for _, f := range p.frames {
frames = append(frames, logutils.ConvertFrame(f.Frame))
}
for _, f := range p.streamFrames {
frames = append(frames, logutils.ConvertFrame(f.Frame))
}
var ack *logging.AckFrame
if p.ack != nil {
ack = logutils.ConvertAckFrame(p.ack)
@ -2296,11 +2309,11 @@ func (s *connection) SendMessage(p []byte) error {
return s.datagramQueue.AddAndWait(f)
}
func (s *connection) ReceiveMessage() ([]byte, error) {
func (s *connection) ReceiveMessage(ctx context.Context) ([]byte, error) {
if !s.config.EnableDatagrams {
return nil, errors.New("datagram support disabled")
}
return s.datagramQueue.Receive()
return s.datagramQueue.Receive(ctx)
}
func (s *connection) LocalAddr() net.Addr {

View file

@ -71,17 +71,9 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
if len(s.msgBuf) < 4 {
return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
b := s.msgBuf
s.msgBuf = nil
return b
}
func (s *cryptoStreamImpl) Finish() error {

View file

@ -3,12 +3,14 @@ package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) bool
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() handshake.Event
}
type cryptoStreamManager struct {
@ -33,7 +35,7 @@ func newCryptoStreamManager(
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
var str cryptoStream
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
@ -44,18 +46,37 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
case protocol.Encryption1RTT:
str = m.oneRTTStream
default:
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return false, err
return err
}
for {
data := str.GetCryptoData()
if data == nil {
return false, nil
return nil
}
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
return true, str.Finish()
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
return err
}
}
}
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
if !m.oneRTTStream.HasData() {
return nil
}
return m.oneRTTStream.PopCryptoFrame(maxSize)
}
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
//nolint:exhaustive // 1-RTT keys should never get dropped.
switch encLevel {
case protocol.EncryptionInitial:
return m.initialStream.Finish()
case protocol.EncryptionHandshake:
return m.handshakeStream.Finish()
default:
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
}
}

View file

@ -1,6 +1,7 @@
package quic
import (
"context"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
@ -98,7 +99,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
}
// Receive gets a received DATAGRAM frame.
func (h *datagramQueue) Receive() ([]byte, error) {
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
for {
h.rcvMx.Lock()
if len(h.rcvQueue) > 0 {
@ -113,6 +114,8 @@ func (h *datagramQueue) Receive() ([]byte, error) {
continue
case <-h.closed:
return nil, h.closeErr
case <-ctx.Done():
return nil, ctx.Err()
}
}
}

View file

@ -15,7 +15,6 @@ import (
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
@ -328,31 +327,43 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
return rsp, rerr.err
}
func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
defer body.Close()
b := make([]byte, bodyCopyBufferSize)
for {
n, rerr := body.Read(b)
if n == 0 {
if rerr == nil {
continue
}
if rerr == io.EOF {
break
}
}
if _, err := str.Write(b[:n]); err != nil {
return err
}
if rerr != nil {
if rerr == io.EOF {
break
}
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
return rerr
}
// cancelingReader reads from the io.Reader.
// It cancels writing on the stream if any error other than io.EOF occurs.
type cancelingReader struct {
r io.Reader
str Stream
}
func (r *cancelingReader) Read(b []byte) (int, error) {
n, err := r.r.Read(b)
if err != nil && err != io.EOF {
r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
}
return nil
return n, err
}
func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
defer body.Close()
buf := make([]byte, bodyCopyBufferSize)
sr := &cancelingReader{str: str, r: body}
if contentLength == -1 {
_, err := io.CopyBuffer(str, sr, buf)
return err
}
// make sure we don't send more bytes than the content length
n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
if err != nil {
return err
}
var extra int64
extra, err = io.CopyBuffer(io.Discard, sr, buf)
n += extra
if n > contentLength {
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
}
return err
}
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
@ -372,7 +383,13 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
if req.Body != nil {
// send the request body asynchronously
go func() {
if err := c.sendRequestBody(hstr, req.Body); err != nil {
contentLength := int64(-1)
// According to the documentation for http.Request.ContentLength,
// a value of 0 with a non-nil Body is also treated as unknown content length.
if req.ContentLength > 0 {
contentLength = req.ContentLength
}
if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
c.logger.Errorf("Error writing request: %s", err)
}
if !opt.DontCloseRequestStream {
@ -402,28 +419,22 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
return nil, newConnError(ErrCodeGeneralProtocolError, err)
}
connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS)
res := &http.Response{
Proto: "HTTP/3.0",
ProtoMajor: 3,
Header: http.Header{},
TLS: &connState,
Request: req,
res, err := responseFromHeaders(hfs)
if err != nil {
return nil, newStreamError(ErrCodeMessageError, err)
}
for _, hf := range hfs {
switch hf.Name {
case ":status":
status, err := strconv.Atoi(hf.Value)
if err != nil {
return nil, newStreamError(ErrCodeGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
}
res.StatusCode = status
res.Status = hf.Value + " " + http.StatusText(status)
default:
res.Header.Add(hf.Name, hf.Value)
}
connState := conn.ConnectionState().TLS
res.TLS = &connState
res.Request = req
// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
// See section 4.1.2 of RFC 9114.
var httpStr Stream
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
httpStr = newLengthLimitedStream(hstr, req.ContentLength)
} else {
httpStr = hstr
}
respBody := newResponseBody(hstr, conn, reqDone)
respBody := newResponseBody(httpStr, conn, reqDone)
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
_, hasTransferEncoding := res.Header["Transfer-Encoding"]

198
vendor/github.com/quic-go/quic-go/http3/headers.go generated vendored Normal file
View file

@ -0,0 +1,198 @@
package http3
import (
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
"github.com/quic-go/qpack"
)
type header struct {
// Pseudo header fields defined in RFC 9114
Path string
Method string
Authority string
Scheme string
Status string
// for Extended connect
Protocol string
// parsed and deduplicated
ContentLength int64
// all non-pseudo headers
Headers http.Header
}
func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) {
hdr := header{Headers: make(http.Header, len(headers))}
var readFirstRegularHeader, readContentLength bool
var contentLengthStr string
for _, h := range headers {
// field names need to be lowercase, see section 4.2 of RFC 9114
if strings.ToLower(h.Name) != h.Name {
return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name)
}
if !httpguts.ValidHeaderFieldValue(h.Value) {
return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value)
}
if h.IsPseudo() {
if readFirstRegularHeader {
// all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114
return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name)
}
var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses
switch h.Name {
case ":path":
hdr.Path = h.Value
case ":method":
hdr.Method = h.Value
case ":authority":
hdr.Authority = h.Value
case ":protocol":
hdr.Protocol = h.Value
case ":scheme":
hdr.Scheme = h.Value
case ":status":
hdr.Status = h.Value
isResponsePseudoHeader = true
default:
return header{}, fmt.Errorf("unknown pseudo header: %s", h.Name)
}
if isRequest && isResponsePseudoHeader {
return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name)
}
if !isRequest && !isResponsePseudoHeader {
return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name)
}
} else {
if !httpguts.ValidHeaderFieldName(h.Name) {
return header{}, fmt.Errorf("invalid header field name: %q", h.Name)
}
readFirstRegularHeader = true
switch h.Name {
case "content-length":
// Ignore duplicate Content-Length headers.
// Fail if the duplicates differ.
if !readContentLength {
readContentLength = true
contentLengthStr = h.Value
} else if contentLengthStr != h.Value {
return header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value)
}
default:
hdr.Headers.Add(h.Name, h.Value)
}
}
}
if len(contentLengthStr) > 0 {
// use ParseUint instead of ParseInt, so that parsing fails on negative values
cl, err := strconv.ParseUint(contentLengthStr, 10, 63)
if err != nil {
return header{}, fmt.Errorf("invalid content length: %w", err)
}
hdr.Headers.Set("Content-Length", contentLengthStr)
hdr.ContentLength = int64(cl)
}
return hdr, nil
}
func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) {
hdr, err := parseHeaders(headerFields, true)
if err != nil {
return nil, err
}
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(hdr.Headers["Cookie"]) > 0 {
hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; "))
}
isConnect := hdr.Method == http.MethodConnect
// Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4
isExtendedConnected := isConnect && hdr.Protocol != ""
if isExtendedConnected {
if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" {
return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty")
}
} else if isConnect {
if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT
return nil, errors.New(":path must be empty and :authority must not be empty")
}
} else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty")
}
var u *url.URL
var requestURI string
var protocol string
if isConnect {
u = &url.URL{}
if isExtendedConnected {
u, err = url.ParseRequestURI(hdr.Path)
if err != nil {
return nil, err
}
} else {
u.Path = hdr.Path
}
u.Scheme = hdr.Scheme
u.Host = hdr.Authority
requestURI = hdr.Authority
protocol = hdr.Protocol
} else {
protocol = "HTTP/3.0"
u, err = url.ParseRequestURI(hdr.Path)
if err != nil {
return nil, fmt.Errorf("invalid content length: %w", err)
}
requestURI = hdr.Path
}
return &http.Request{
Method: hdr.Method,
URL: u,
Proto: protocol,
ProtoMajor: 3,
ProtoMinor: 0,
Header: hdr.Headers,
Body: nil,
ContentLength: hdr.ContentLength,
Host: hdr.Authority,
RequestURI: requestURI,
}, nil
}
func hostnameFromRequest(req *http.Request) string {
if req.URL != nil {
return req.URL.Host
}
return ""
}
func responseFromHeaders(headerFields []qpack.HeaderField) (*http.Response, error) {
hdr, err := parseHeaders(headerFields, false)
if err != nil {
return nil, err
}
if hdr.Status == "" {
return nil, errors.New("missing status field")
}
rsp := &http.Response{
Proto: "HTTP/3.0",
ProtoMajor: 3,
Header: hdr.Headers,
ContentLength: hdr.ContentLength,
}
status, err := strconv.Atoi(hdr.Status)
if err != nil {
return nil, fmt.Errorf("invalid status code: %w", err)
}
rsp.StatusCode = status
rsp.Status = hdr.Status + " " + http.StatusText(status)
return rsp, nil
}

View file

@ -1,9 +1,11 @@
package http3
import (
"errors"
"fmt"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
)
// A Stream is a HTTP/3 stream.
@ -66,6 +68,10 @@ func (s *stream) Read(b []byte) (int, error) {
return n, err
}
func (s *stream) hasMoreData() bool {
return s.bytesRemainingInFrame > 0
}
func (s *stream) Write(b []byte) (int, error) {
s.buf = s.buf[:0]
s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf)
@ -74,3 +80,45 @@ func (s *stream) Write(b []byte) (int, error) {
}
return s.Stream.Write(b)
}
var errTooMuchData = errors.New("peer sent too much data")
type lengthLimitedStream struct {
*stream
contentLength int64
read int64
resetStream bool
}
var _ Stream = &lengthLimitedStream{}
func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
return &lengthLimitedStream{
stream: str,
contentLength: contentLength,
}
}
func (s *lengthLimitedStream) checkContentLengthViolation() error {
if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() {
if !s.resetStream {
s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
s.resetStream = true
}
return errTooMuchData
}
return nil
}
func (s *lengthLimitedStream) Read(b []byte) (int, error) {
if err := s.checkContentLengthViolation(); err != nil {
return 0, err
}
n, err := s.stream.Read(b[:utils.Min(int64(len(b)), s.contentLength-s.read)])
s.read += int64(n)
if err := s.checkContentLengthViolation(); err != nil {
return n, err
}
return n, err
}

View file

@ -1,111 +0,0 @@
package http3
import (
"errors"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/quic-go/qpack"
)
func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) {
var path, authority, method, protocol, scheme, contentLengthStr string
httpHeaders := http.Header{}
for _, h := range headers {
switch h.Name {
case ":path":
path = h.Value
case ":method":
method = h.Value
case ":authority":
authority = h.Value
case ":protocol":
protocol = h.Value
case ":scheme":
scheme = h.Value
case "content-length":
contentLengthStr = h.Value
default:
if !h.IsPseudo() {
httpHeaders.Add(h.Name, h.Value)
}
}
}
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(httpHeaders["Cookie"]) > 0 {
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
}
isConnect := method == http.MethodConnect
// Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4
isExtendedConnected := isConnect && protocol != ""
if isExtendedConnected {
if scheme == "" || path == "" || authority == "" {
return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty")
}
} else if isConnect {
if path != "" || authority == "" { // normal CONNECT
return nil, errors.New(":path must be empty and :authority must not be empty")
}
} else if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty")
}
var u *url.URL
var requestURI string
var err error
if isConnect {
u = &url.URL{}
if isExtendedConnected {
u, err = url.ParseRequestURI(path)
if err != nil {
return nil, err
}
} else {
u.Path = path
}
u.Scheme = scheme
u.Host = authority
requestURI = authority
} else {
protocol = "HTTP/3.0"
u, err = url.ParseRequestURI(path)
if err != nil {
return nil, err
}
requestURI = path
}
var contentLength int64
if len(contentLengthStr) > 0 {
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, err
}
}
return &http.Request{
Method: method,
URL: u,
Proto: protocol,
ProtoMajor: 3,
ProtoMinor: 0,
Header: httpHeaders,
Body: nil,
ContentLength: contentLength,
Host: authority,
RequestURI: requestURI,
}, nil
}
func hostnameFromRequest(req *http.Request) string {
if req.URL != nil {
return req.URL.Host
}
return ""
}

View file

@ -2,6 +2,7 @@ package http3
import (
"bytes"
"errors"
"fmt"
"io"
"net"
@ -81,6 +82,9 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
if err != nil {
return err
}
if !httpguts.ValidHostHeader(host) {
return errors.New("http3: invalid Host header")
}
// http.NewRequest sets this field to HTTP/1.1
isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1"

View file

@ -3,6 +3,7 @@ package http3
import (
"bufio"
"bytes"
"fmt"
"net/http"
"strconv"
"strings"
@ -23,6 +24,8 @@ type responseWriter struct {
header http.Header
status int // status code passed to WriteHeader
headerWritten bool
contentLen int64 // if handler set valid Content-Length header
numWritten int64 // bytes written
logger utils.Logger
}
@ -53,8 +56,30 @@ func (w *responseWriter) WriteHeader(status int) {
return
}
if status < 100 || status >= 200 {
// http status must be 3 digits
if status < 100 || status > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", status))
}
if status >= 200 {
w.headerWritten = true
// Add Date header.
// This is what the standard library does.
// Can be disabled by setting the Date header to nil.
if _, ok := w.header["Date"]; !ok {
w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
}
// Content-Length checking
// use ParseUint instead of ParseInt, as negative values are invalid
if clen := w.header.Get("Content-Length"); clen != "" {
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
w.contentLen = int64(cl)
} else {
// emit a warning for malformed Content-Length and remove it
w.logger.Errorf("Malformed Content-Length %s", clen)
w.header.Del("Content-Length")
}
}
}
w.status = status
@ -105,6 +130,12 @@ func (w *responseWriter) Write(p []byte) (int, error) {
if !bodyAllowed {
return 0, http.ErrBodyNotAllowed
}
w.numWritten += int64(len(p))
if w.contentLen != 0 && w.numWritten > w.contentLen {
return 0, http.ErrContentLength
}
df := &dataFrame{Length: uint64(len(p))}
w.buf = w.buf[:0]
w.buf = df.Append(w.buf)
@ -114,8 +145,12 @@ func (w *responseWriter) Write(p []byte) (int, error) {
return w.bufferedStr.Write(p)
}
func (w *responseWriter) FlushError() error {
return w.bufferedStr.Flush()
}
func (w *responseWriter) Flush() {
if err := w.bufferedStr.Flush(); err != nil {
if err := w.FlushError(); err != nil {
w.logger.Errorf("could not flush to stream: %s", err.Error())
}
}

View file

@ -62,8 +62,6 @@ func versionToALPN(v protocol.VersionNumber) string {
switch v {
case protocol.Version1, protocol.Version2:
return NextProtoH3
case protocol.VersionDraft29:
return NextProtoH3Draft29
default:
return ""
}
@ -575,14 +573,22 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
}
req, err := requestFromHeaders(hfs)
if err != nil {
// TODO: use the right error code
return newStreamError(ErrCodeGeneralProtocolError, err)
return newStreamError(ErrCodeMessageError, err)
}
connState := conn.ConnectionState().TLS.ConnectionState
connState := conn.ConnectionState().TLS
req.TLS = &connState
req.RemoteAddr = conn.RemoteAddr().String()
body := newRequestBody(newStream(str, onFrameError))
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
// See section 4.1.2 of RFC 9114.
var httpStr Stream
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
httpStr = newLengthLimitedStream(newStream(str, onFrameError), req.ContentLength)
} else {
httpStr = newStream(str, onFrameError)
}
body := newRequestBody(httpStr)
req.Body = body
if s.logger.Debug() {
@ -596,7 +602,6 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
req = req.WithContext(ctx)
r := newResponseWriter(str, conn, s.logger)
defer r.Flush()
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
@ -624,10 +629,10 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
return requestError{err: errHijacked}
}
if panicked {
r.WriteHeader(http.StatusInternalServerError)
} else {
// only write response when there is no panic
if !panicked {
r.WriteHeader(http.StatusOK)
r.Flush()
}
// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))

View file

@ -2,6 +2,7 @@ package quic
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
@ -19,10 +20,9 @@ type StreamID = protocol.StreamID
type VersionNumber = protocol.VersionNumber
const (
// VersionDraft29 is IETF QUIC draft-29
VersionDraft29 = protocol.VersionDraft29
// Version1 is RFC 9000
Version1 = protocol.Version1
// Version2 is RFC 9369
Version2 = protocol.Version2
)
@ -122,6 +122,8 @@ type SendStream interface {
// The Context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() or CancelWrite() is called, or when the peer
// cancels the read-side of their stream.
// The cancellation cause is set to the error that caused the stream to
// close, or `context.Canceled` in case the stream is closed without error.
Context() context.Context
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
@ -178,6 +180,8 @@ type Connection interface {
// The error string will be sent to the peer.
CloseWithError(ApplicationErrorCode, string) error
// Context returns a context that is cancelled when the connection is closed.
// The cancellation cause is set to the error that caused the connection to
// close, or `context.Canceled` in case the listener is closed first.
Context() context.Context
// ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
@ -186,7 +190,7 @@ type Connection interface {
// SendMessage sends a message as a datagram, as specified in RFC 9221.
SendMessage([]byte) error
// ReceiveMessage gets a message received in a datagram, as specified in RFC 9221.
ReceiveMessage() ([]byte, error)
ReceiveMessage(context.Context) ([]byte, error)
}
// An EarlyConnection is a connection that is handshaking.
@ -337,12 +341,14 @@ type ClientHelloInfo struct {
// ConnectionState records basic details about a QUIC connection
type ConnectionState struct {
// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
TLS handshake.ConnectionState
TLS tls.ConnectionState
// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
// If datagram support was negotiated, datagrams can be sent and received using the
// SendMessage and ReceiveMessage methods on the Connection.
SupportsDatagrams bool
// Used0RTT says if 0-RTT resumption was used.
Used0RTT bool
// Version is the QUIC version of the QUIC connection.
Version VersionNumber
}

View file

@ -47,6 +47,11 @@ func (h *receivedPacketHandler) ReceivedPacket(
case protocol.EncryptionInitial:
return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
case protocol.EncryptionHandshake:
// The Handshake packet number space might already have been dropped as a result
// of processing the CRYPTO frame that was contained in this packet.
if h.handshakePackets == nil {
return nil
}
return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
case protocol.Encryption0RTT:
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {

View file

@ -136,16 +136,6 @@ func newSentPacketHandler(
}
}
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial {
// This function is called when the crypto setup seals a Handshake packet.
// If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space
// before SentPacket() was called for that Initial packet.
return
}
h.dropPackets(encLevel)
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight {
@ -156,7 +146,7 @@ func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
}
}
func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
// The server won't await address validation after the handshake is confirmed.
// This applies even if we didn't receive an ACK for a Handshake packet.
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
@ -165,6 +155,10 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
// remove outstanding packets from bytes_in_flight
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
pnSpace := h.getPacketNumberSpace(encLevel)
// We might already have dropped this packet number space.
if pnSpace == nil {
return
}
pnSpace.history.Iterate(func(p *packet) (bool, error) {
h.removeFromBytesInFlight(p)
return true, nil
@ -238,10 +232,6 @@ func (h *sentPacketHandler) SentPacket(
isPathMTUProbePacket bool,
) {
h.bytesSent += size
// For the client, drop the Initial packet number space when the first Handshake packet is sent.
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake && h.initialPackets != nil {
h.dropPackets(protocol.EncryptionInitial)
}
pnSpace := h.getPacketNumberSpace(encLevel)
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
@ -884,6 +874,12 @@ func (h *sentPacketHandler) ResetForRetry() error {
}
func (h *sentPacketHandler) SetHandshakeConfirmed() {
if h.initialPackets != nil {
panic("didn't drop initial correctly")
}
if h.handshakePackets != nil {
panic("didn't drop handshake correctly")
}
h.handshakeConfirmed = true
// We don't send PTOs for application data packets before the handshake completes.
// Make sure the timer is armed now, if necessary.

View file

@ -5,11 +5,10 @@ import (
"encoding/binary"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
)
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1
if v == protocol.Version2 {
@ -93,69 +92,3 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
}
type handshakeSealer struct {
LongHeaderSealer
dropInitialKeys func()
dropped bool
}
func newHandshakeSealer(
aead cipher.AEAD,
headerProtector headerProtector,
dropInitialKeys func(),
perspective protocol.Perspective,
) LongHeaderSealer {
sealer := newLongHeaderSealer(aead, headerProtector)
// The client drops Initial keys when sending the first Handshake packet.
if perspective == protocol.PerspectiveServer {
return sealer
}
return &handshakeSealer{
LongHeaderSealer: sealer,
dropInitialKeys: dropInitialKeys,
}
}
func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
data := s.LongHeaderSealer.Seal(dst, src, pn, ad)
if !s.dropped {
s.dropInitialKeys()
s.dropped = true
}
return data
}
type handshakeOpener struct {
LongHeaderOpener
dropInitialKeys func()
dropped bool
}
func newHandshakeOpener(
aead cipher.AEAD,
headerProtector headerProtector,
dropInitialKeys func(),
perspective protocol.Perspective,
) LongHeaderOpener {
opener := newLongHeaderOpener(aead, headerProtector)
// The server drops Initial keys when first successfully processing a Handshake packet.
if perspective == protocol.PerspectiveClient {
return opener
}
return &handshakeOpener{
LongHeaderOpener: opener,
dropInitialKeys: dropInitialKeys,
}
}
func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad)
if err == nil && !o.dropped {
o.dropInitialKeys()
o.dropped = true
}
return dec, err
}

View file

@ -0,0 +1,104 @@
package handshake
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
)
// These cipher suite implementations are copied from the standard library crypto/tls package.
const aeadNonceLength = 12
type cipherSuite struct {
ID uint16
Hash crypto.Hash
KeyLen int
AEAD func(key, nonceMask []byte) cipher.AEAD
}
func (s cipherSuite) IVLen() int { return aeadNonceLength }
func getCipherSuite(id uint16) *cipherSuite {
switch id {
case tls.TLS_AES_128_GCM_SHA256:
return &cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13}
case tls.TLS_CHACHA20_POLY1305_SHA256:
return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305}
case tls.TLS_AES_256_GCM_SHA384:
return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadAESGCMTLS13}
default:
panic(fmt.Sprintf("unknown cypher suite: %d", id))
}
}
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}

View file

@ -6,10 +6,8 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/protocol"
@ -25,98 +23,21 @@ type quicVersionContextKey struct{}
var QUICVersionContextKey = &quicVersionContextKey{}
// TLS unexpected_message alert
const alertUnexpectedMessage uint8 = 10
type messageType uint8
// TLS handshake message types.
const (
typeClientHello messageType = 1
typeServerHello messageType = 2
typeNewSessionTicket messageType = 4
typeEncryptedExtensions messageType = 8
typeCertificate messageType = 11
typeCertificateRequest messageType = 13
typeCertificateVerify messageType = 15
typeFinished messageType = 20
)
func (m messageType) String() string {
switch m {
case typeClientHello:
return "ClientHello"
case typeServerHello:
return "ServerHello"
case typeNewSessionTicket:
return "NewSessionTicket"
case typeEncryptedExtensions:
return "EncryptedExtensions"
case typeCertificate:
return "Certificate"
case typeCertificateRequest:
return "CertificateRequest"
case typeCertificateVerify:
return "CertificateVerify"
case typeFinished:
return "Finished"
default:
return fmt.Sprintf("unknown message type: %d", m)
}
}
const clientSessionStateRevision = 3
type conn struct {
localAddr, remoteAddr net.Addr
}
var _ net.Conn = &conn{}
func newConn(local, remote net.Addr) net.Conn {
return &conn{
localAddr: local,
remoteAddr: remote,
}
}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
type cryptoSetup struct {
tlsConf *tls.Config
extraConf *qtls.ExtraConfig
conn *qtls.Conn
tlsConf *tls.Config
conn *qtls.QUICConn
events []Event
version protocol.VersionNumber
messageChan chan []byte
isReadingHandshakeMessage chan struct{}
readFirstHandshakeMessage bool
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
paramsChan <-chan []byte
runner handshakeRunner
alertChan chan uint8
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
handshakeDone chan struct{}
// is closed when Close() is called
closeChan chan struct{}
zeroRTTParameters *wire.TransportParameters
clientHelloWritten bool
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
zeroRTTParametersChan chan<- *wire.TransportParameters
allow0RTT bool
zeroRTTParameters *wire.TransportParameters
allow0RTT bool
rttStats *utils.RTTStats
@ -129,73 +50,61 @@ type cryptoSetup struct {
handshakeCompleteTime time.Time
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
zeroRTTOpener LongHeaderOpener // only set for the server
zeroRTTSealer LongHeaderSealer // only set for the client
initialStream io.Writer
initialOpener LongHeaderOpener
initialSealer LongHeaderSealer
handshakeStream io.Writer
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
used0RTT atomic.Bool
aead *updatableAEAD
has1RTTSealer bool
has1RTTOpener bool
}
var (
_ qtls.RecordLayer = &cryptoSetup{}
_ CryptoSetup = &cryptoSetup{}
)
var _ CryptoSetup = &cryptoSetup{}
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
localAddr net.Addr,
remoteAddr net.Addr,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
cs, clientHelloWritten := newCryptoSetup(
initialStream,
handshakeStream,
) CryptoSetup {
cs := newCryptoSetup(
connID,
tp,
runner,
tlsConf,
enable0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveClient,
version,
)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
return cs, clientHelloWritten
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
cs.tlsConf = tlsConf
cs.conn = qtls.QUICClient(quicConf)
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
return cs
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
localAddr net.Addr,
remoteAddr net.Addr,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
allow0RTT bool,
rttStats *utils.RTTStats,
@ -203,88 +112,52 @@ func NewCryptoSetupServer(
logger utils.Logger,
version protocol.VersionNumber,
) CryptoSetup {
cs, _ := newCryptoSetup(
initialStream,
handshakeStream,
cs := newCryptoSetup(
connID,
tp,
runner,
tlsConf,
allow0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
cs.allow0RTT = allow0RTT
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT)
cs.tlsConf = quicConf.TLSConfig
cs.conn = qtls.QUICServer(quicConf)
return cs
}
func newCryptoSetup(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
perspective protocol.Perspective,
version protocol.VersionNumber,
) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
) *cryptoSetup {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil {
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version)
zeroRTTParametersChan := make(chan *wire.TransportParameters, 1)
cs := &cryptoSetup{
tlsConf: tlsConf,
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
allow0RTT: enable0RTT,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan struct{}),
zeroRTTParametersChan: zeroRTTParametersChan,
messageChan: make(chan []byte, 1),
isReadingHandshakeMessage: make(chan struct{}),
closeChan: make(chan struct{}),
version: version,
return &cryptoSetup{
initialSealer: initialSealer,
initialOpener: initialOpener,
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
events: make([]Event, 0, 16),
ourParams: tp,
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
version: version,
}
var maxEarlyData uint32
if enable0RTT {
maxEarlyData = math.MaxUint32
}
cs.extraConf = &qtls.ExtraConfig{
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
AlternativeRecordLayer: cs,
EnforceNextProtoSelection: true,
MaxEarlyData: maxEarlyData,
Accept0RTT: cs.accept0RTT,
Rejected0RTT: cs.rejected0RTT,
Enable0RTT: enable0RTT,
GetAppDataForSessionState: cs.marshalDataForSessionState,
SetAppDataFromSessionState: cs.handleDataFromSessionState,
}
return cs, zeroRTTParametersChan
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
@ -301,142 +174,109 @@ func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) RunHandshake() {
// Handle errors that might occur when HandleData() is called.
handshakeComplete := make(chan struct{})
handshakeErrChan := make(chan error, 1)
go func() {
defer close(h.handshakeDone)
if err := h.conn.HandshakeContext(context.WithValue(context.Background(), QUICVersionContextKey, h.version)); err != nil {
handshakeErrChan <- err
return
func (h *cryptoSetup) StartHandshake() error {
err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
if err != nil {
return wrapError(err)
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
return wrapError(err)
}
close(handshakeComplete)
}()
if done {
break
}
}
if h.perspective == protocol.PerspectiveClient {
select {
case err := <-handshakeErrChan:
h.onError(0, err.Error())
return
case <-h.clientHelloWrittenChan:
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters})
} else {
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
}
}
select {
case <-handshakeComplete: // return when the handshake is done
h.mutex.Lock()
h.handshakeCompleteTime = time.Now()
h.mutex.Unlock()
h.runner.OnHandshakeComplete()
case <-h.closeChan:
// wait until the Handshake() go routine has returned
<-h.handshakeDone
case alert := <-h.alertChan:
handshakeErr := <-handshakeErrChan
h.onError(alert, handshakeErr.Error())
}
}
func (h *cryptoSetup) onError(alert uint8, message string) {
var err error
if alert == 0 {
err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message}
} else {
err = qerr.NewLocalCryptoError(alert, message)
}
h.runner.OnError(err)
return nil
}
// Close closes the crypto setup.
// It aborts the handshake, if it is still running.
// It must only be called once.
func (h *cryptoSetup) Close() error {
close(h.closeChan)
// wait until qtls.Handshake() actually returned
<-h.handshakeDone
return nil
return h.conn.Close()
}
// handleMessage handles a TLS handshake message.
// HandleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
// It returns if it is done with messages on the same encryption level.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
h.onError(alertUnexpectedMessage, err.Error())
return false
}
if encLevel != protocol.Encryption1RTT {
select {
case h.messageChan <- data:
case <-h.handshakeDone: // handshake errored, nobody is going to consume this message
return false
}
}
if encLevel == protocol.Encryption1RTT {
h.messageChan <- data
h.handlePostHandshakeMessage()
return false
}
readLoop:
for {
select {
case data := <-h.paramsChan:
if data == nil {
h.onError(0x6d, "missing quic_transport_parameters extension")
} else {
h.handleTransportParameters(data)
}
case <-h.isReadingHandshakeMessage:
break readLoop
case <-h.handshakeDone:
break readLoop
case <-h.closeChan:
break readLoop
}
}
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
// but only if a handshake opener and sealer was created.
// Otherwise, a HelloRetryRequest was performed.
// We're done with the Handshake encryption level after processing the Finished message.
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
msgType == typeFinished
}
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
var expected protocol.EncryptionLevel
switch msgType {
case typeClientHello, typeServerHello:
expected = protocol.EncryptionInitial
case typeEncryptedExtensions,
typeCertificate,
typeCertificateRequest,
typeCertificateVerify,
typeFinished:
expected = protocol.EncryptionHandshake
case typeNewSessionTicket:
expected = protocol.Encryption1RTT
default:
return fmt.Errorf("unexpected handshake message: %d", msgType)
}
if encLevel != expected {
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.handleMessage(data, encLevel); err != nil {
return wrapError(err)
}
return nil
}
func (h *cryptoSetup) handleTransportParameters(data []byte) {
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil {
return err
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
return err
}
if done {
return nil
}
}
}
func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
switch ev.Kind {
case qtls.QUICNoEvent:
return true, nil
case qtls.QUICSetReadSecret:
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
return false, nil
case qtls.QUICSetWriteSecret:
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
return false, nil
case qtls.QUICTransportParameters:
return false, h.handleTransportParameters(ev.Data)
case qtls.QUICTransportParametersRequired:
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
return false, nil
case qtls.QUICRejectedEarlyData:
h.rejected0RTT()
return false, nil
case qtls.QUICWriteData:
h.WriteRecord(ev.Level, ev.Data)
return false, nil
case qtls.QUICHandshakeDone:
h.handshakeComplete()
return false, nil
default:
return false, fmt.Errorf("unexpected event: %d", ev.Kind)
}
}
func (h *cryptoSetup) NextEvent() Event {
if len(h.events) == 0 {
return Event{Kind: EventNoEvent}
}
ev := h.events[0]
h.events = h.events[1:]
return ev
}
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
h.runner.OnError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
})
return err
}
h.peerParams = &tp
h.runner.OnReceivedParams(h.peerParams)
h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams})
return nil
}
// must be called after receiving the transport parameters
@ -477,17 +317,32 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
return &tp, nil
}
// only valid for the server
func (h *cryptoSetup) getDataForSessionTicket() []byte {
return (&sessionTicket{
Parameters: h.ourParams,
RTT: h.rttStats.SmoothedRTT(),
}).Marshal()
}
// GetSessionTicket generates a new session ticket.
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
var appData []byte
// Save transport parameters to the session ticket if we're allowing 0-RTT.
if h.extraConf.MaxEarlyData > 0 {
appData = (&sessionTicket{
Parameters: h.ourParams,
RTT: h.rttStats.SmoothedRTT(),
}).Marshal()
if h.tlsConf.SessionTicketsDisabled {
return nil, nil
}
return h.conn.GetSessionTicket(appData)
if err := h.conn.SendSessionTicket(h.allow0RTT); err != nil {
return nil, err
}
ev := h.conn.NextEvent()
if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication {
panic("crypto/tls bug: where's my session ticket?")
}
ticket := ev.Data
if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent {
panic("crypto/tls bug: why more than one ticket?")
}
return ticket, nil
}
// accept0RTT is called for the server when receiving the client's session ticket.
@ -522,64 +377,16 @@ func (h *cryptoSetup) rejected0RTT() {
h.mutex.Unlock()
if had0RTTKeys {
h.runner.DropKeys(protocol.Encryption0RTT)
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
}
}
func (h *cryptoSetup) handlePostHandshakeMessage() {
// make sure the handshake has already completed
<-h.handshakeDone
done := make(chan struct{})
defer close(done)
// h.alertChan is an unbuffered channel.
// If an error occurs during conn.HandlePostHandshakeMessage,
// it will be sent on this channel.
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
alertChan := make(chan uint8, 1)
go func() {
<-h.isReadingHandshakeMessage
select {
case alert := <-h.alertChan:
alertChan <- alert
case <-done:
}
}()
if err := h.conn.HandlePostHandshakeMessage(); err != nil {
select {
case <-h.closeChan:
case alert := <-alertChan:
h.onError(alert, err.Error())
}
}
}
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
if !h.readFirstHandshakeMessage {
h.readFirstHandshakeMessage = true
} else {
select {
case h.isReadingHandshakeMessage <- struct{}{}:
case <-h.closeChan:
return nil, errors.New("error while handling the handshake message")
}
}
select {
case msg := <-h.messageChan:
return msg, nil
case <-h.closeChan:
return nil, errors.New("error while handling the handshake message")
}
}
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case qtls.QUICEncryptionLevelEarly:
if h.perspective == protocol.PerspectiveClient {
panic("Received 0-RTT read key for the client")
}
@ -587,27 +394,19 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
h.mutex.Unlock()
h.used0RTT.Store(true)
if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite())
}
return
case qtls.EncryptionHandshake:
h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = newHandshakeOpener(
case qtls.QUICEncryptionLevelHandshake:
h.handshakeOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
h.dropInitialKeys,
h.perspective,
)
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case qtls.EncryptionApplication:
h.readEncLevel = protocol.Encryption1RTT
case qtls.QUICEncryptionLevelApplication:
h.aead.SetReadKey(suite, trafficSecret)
h.has1RTTOpener = true
if h.logger.Debug() {
@ -617,15 +416,18 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
panic("unexpected read encryption level")
}
h.mutex.Unlock()
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
}
}
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case qtls.QUICEncryptionLevelEarly:
if h.perspective == protocol.PerspectiveServer {
panic("Received 0-RTT write key for the server")
}
@ -640,26 +442,25 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
}
// don't set used0RTT here. 0-RTT might still get rejected.
return
case qtls.EncryptionHandshake:
h.writeEncLevel = protocol.EncryptionHandshake
h.handshakeSealer = newHandshakeSealer(
case qtls.QUICEncryptionLevelHandshake:
h.handshakeSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
h.dropInitialKeys,
h.perspective,
)
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case qtls.EncryptionApplication:
h.writeEncLevel = protocol.Encryption1RTT
case qtls.QUICEncryptionLevelApplication:
h.aead.SetWriteKey(suite, trafficSecret)
h.has1RTTSealer = true
if h.logger.Debug() {
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
if h.zeroRTTSealer != nil {
// Once we receive handshake keys, we know that 0-RTT was not rejected.
h.used0RTT.Store(true)
h.zeroRTTSealer = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil {
@ -671,55 +472,39 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
}
h.mutex.Unlock()
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
}
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
//nolint:exhaustive // LS records can only be written for Initial and Handshake.
switch h.writeEncLevel {
case protocol.EncryptionInitial:
// assume that the first WriteRecord call contains the ClientHello
n, err := h.initialStream.Write(p)
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
h.clientHelloWritten = true
close(h.clientHelloWrittenChan)
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.zeroRTTParametersChan <- h.zeroRTTParameters
} else {
h.logger.Debugf("Not doing 0-RTT.")
h.zeroRTTParametersChan <- nil
}
}
return n, err
case protocol.EncryptionHandshake:
return h.handshakeStream.Write(p)
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) {
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
switch encLevel {
case qtls.QUICEncryptionLevelInitial:
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
case qtls.QUICEncryptionLevelHandshake:
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
case qtls.QUICEncryptionLevelApplication:
panic("unexpected write")
default:
panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel))
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
}
}
func (h *cryptoSetup) SendAlert(alert uint8) {
select {
case h.alertChan <- alert:
case <-h.closeChan:
// no need to send an alert when we've already closed
}
}
// used a callback in the handshakeSealer and handshakeOpener
func (h *cryptoSetup) dropInitialKeys() {
func (h *cryptoSetup) DiscardInitialKeys() {
h.mutex.Lock()
dropped := h.initialOpener != nil
h.initialOpener = nil
h.initialSealer = nil
h.mutex.Unlock()
h.runner.DropKeys(protocol.EncryptionInitial)
h.logger.Debugf("Dropping Initial keys.")
if dropped {
h.logger.Debugf("Dropping Initial keys.")
}
}
func (h *cryptoSetup) handshakeComplete() {
h.handshakeCompleteTime = time.Now()
h.events = append(h.events, Event{Kind: EventHandshakeComplete})
}
func (h *cryptoSetup) SetHandshakeConfirmed() {
@ -734,7 +519,6 @@ func (h *cryptoSetup) SetHandshakeConfirmed() {
}
h.mutex.Unlock()
if dropped {
h.runner.DropKeys(protocol.EncryptionHandshake)
h.logger.Debugf("Dropping Handshake keys.")
}
}
@ -839,5 +623,15 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
}
func (h *cryptoSetup) ConnectionState() ConnectionState {
return qtls.GetConnectionState(h.conn)
return ConnectionState{
ConnectionState: h.conn.ConnectionState(),
Used0RTT: h.used0RTT.Load(),
}
}
func wrapError(err error) error {
if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
return qerr.NewLocalCryptoError(uint8(alertErr), err.Error())
}
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
}

View file

@ -10,7 +10,6 @@ import (
"golang.org/x/crypto/chacha20"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
)
type headerProtector interface {
@ -25,7 +24,7 @@ func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
return "quic hp"
}
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
hkdfLabel := hkdfHeaderProtectionLabel(v)
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
@ -45,7 +44,7 @@ type aesHeaderProtector struct {
var _ headerProtector = &aesHeaderProtector{}
func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
block, err := aes.NewCipher(hpKey)
if err != nil {
@ -90,7 +89,7 @@ type chachaHeaderProtector struct {
var _ headerProtector = &chachaHeaderProtector{}
func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
func newChaChaHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
p := &chachaHeaderProtector{

View file

@ -7,13 +7,11 @@ import (
"golang.org/x/crypto/hkdf"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
)
var (
quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9}
quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9}
)
const (
@ -27,18 +25,10 @@ func getSalt(v protocol.VersionNumber) []byte {
if v == protocol.Version2 {
return quicSaltV2
}
if v == protocol.Version1 {
return quicSaltV1
}
return quicSaltOld
return quicSaltV1
}
var initialSuite = &qtls.CipherSuiteTLS13{
ID: tls.TLS_AES_128_GCM_SHA256,
KeyLen: 16,
AEAD: qtls.AEADAESGCMTLS13,
Hash: crypto.SHA256,
}
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
@ -54,8 +44,8 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p
myKey, myIV := computeInitialKeyAndIV(mySecret, v)
otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v)
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
encrypter := initialSuite.AEAD(myKey, myIV)
decrypter := initialSuite.AEAD(otherKey, otherIV)
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)),
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))

View file

@ -1,12 +1,12 @@
package handshake
import (
"crypto/tls"
"errors"
"io"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/wire"
)
@ -22,9 +22,6 @@ var (
ErrDecryptionFailed = errors.New("decryption failed")
)
// ConnectionState contains information about the state of the connection.
type ConnectionState = qtls.ConnectionState
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
@ -56,29 +53,54 @@ type ShortHeaderSealer interface {
KeyPhase() protocol.KeyPhaseBit
}
// A tlsExtensionHandler sends and received the QUIC TLS extension.
type tlsExtensionHandler interface {
GetExtensions(msgType uint8) []qtls.Extension
ReceivedExtensions(msgType uint8, exts []qtls.Extension)
TransportParameters() <-chan []byte
type ConnectionState struct {
tls.ConnectionState
Used0RTT bool
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnError(error)
DropKeys(protocol.EncryptionLevel)
// EventKind is the kind of handshake event.
type EventKind uint8
const (
// EventNoEvent signals that there are no new handshake events
EventNoEvent EventKind = iota + 1
// EventWriteInitialData contains new CRYPTO data to send at the Initial encryption level
EventWriteInitialData
// EventWriteHandshakeData contains new CRYPTO data to send at the Handshake encryption level
EventWriteHandshakeData
// EventReceivedReadKeys signals that new decryption keys are available.
// It doesn't say which encryption level those keys are for.
EventReceivedReadKeys
// EventDiscard0RTTKeys signals that the Handshake keys were discarded.
EventDiscard0RTTKeys
// EventReceivedTransportParameters contains the transport parameters sent by the peer.
EventReceivedTransportParameters
// EventRestoredTransportParameters contains the transport parameters restored from the session ticket.
// It is only used for the client.
EventRestoredTransportParameters
// EventHandshakeComplete signals that the TLS handshake was completed.
EventHandshakeComplete
)
// Event is a handshake event.
type Event struct {
Kind EventKind
Data []byte
TransportParameters *wire.TransportParameters
}
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake()
StartHandshake() error
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)
HandleMessage([]byte, protocol.EncryptionLevel) bool
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() Event
SetLargest1RTTAcked(protocol.PacketNumber) error
DiscardInitialKeys()
SetHandshakeConfirmed()
ConnectionState() ConnectionState

View file

@ -1,6 +0,0 @@
//go:build gomock || generate
package handshake
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package handshake -destination mock_handshake_runner_test.go github.com/quic-go/quic-go/internal/handshake HandshakeRunner"
type HandshakeRunner = handshakeRunner

View file

@ -11,13 +11,11 @@ import (
)
var (
retryAEADdraft29 cipher.AEAD // used for QUIC draft versions up to 34
retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000)
retryAEADv2 cipher.AEAD // used for QUIC v2
retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000)
retryAEADv2 cipher.AEAD // used for QUIC v2 (RFC 9369)
)
func init() {
retryAEADdraft29 = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1})
retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
}
@ -35,11 +33,10 @@ func initAEAD(key [16]byte) cipher.AEAD {
}
var (
retryBuf bytes.Buffer
retryMutex sync.Mutex
retryNonceDraft29 = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c}
retryNonceV1 = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
retryNonceV2 = [12]byte{0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a}
retryBuf bytes.Buffer
retryMutex sync.Mutex
retryNonceV1 = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
retryNonceV2 = [12]byte{0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a}
)
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
@ -54,14 +51,10 @@ func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, ve
var tag [16]byte
var sealed []byte
//nolint:exhaustive // These are all the versions we support
switch version {
case protocol.Version1:
sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes())
case protocol.Version2:
if version == protocol.Version2 {
sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes())
default:
sealed = retryAEADdraft29.Seal(tag[:0], retryNonceDraft29[:], nil, retryBuf.Bytes())
} else {
sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes())
}
if len(sealed) != 16 {
panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed)))

View file

@ -10,7 +10,7 @@ import (
"github.com/quic-go/quic-go/quicvarint"
)
const sessionTicketRevision = 2
const sessionTicketRevision = 3
type sessionTicket struct {
Parameters *wire.TransportParameters

View file

@ -1,68 +0,0 @@
package handshake
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
)
const (
quicTLSExtensionTypeOldDrafts = 0xffa5
quicTLSExtensionType = 0x39
)
type extensionHandler struct {
ourParams []byte
paramsChan chan []byte
extensionType uint16
perspective protocol.Perspective
}
var _ tlsExtensionHandler = &extensionHandler{}
// newExtensionHandler creates a new extension handler
func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler {
et := uint16(quicTLSExtensionType)
if v == protocol.VersionDraft29 {
et = quicTLSExtensionTypeOldDrafts
}
return &extensionHandler{
ourParams: params,
paramsChan: make(chan []byte),
perspective: pers,
extensionType: et,
}
}
func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) {
return nil
}
return []qtls.Extension{{
Type: h.extensionType,
Data: h.ourParams,
}}
}
func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) {
return
}
var data []byte
for _, ext := range exts {
if ext.Type == h.extensionType {
data = ext.Data
break
}
}
h.paramsChan <- data
}
func (h *extensionHandler) TransportParameters() <-chan []byte {
return h.paramsChan
}

View file

@ -10,7 +10,6 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
@ -24,7 +23,7 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
var FirstKeyUpdateInterval uint64 = 100
type updatableAEAD struct {
suite *qtls.CipherSuiteTLS13
suite *cipherSuite
keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber
@ -121,7 +120,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
// SetReadKey sets the read key.
// For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
if a.suite == nil {
@ -135,7 +134,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [
// SetWriteKey sets the write key.
// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey.
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
if a.suite == nil {
@ -146,7 +145,7 @@ func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
}
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) {
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) {
a.nonceBuf = make([]byte, aead.NonceSize())
a.aeadOverhead = aead.Overhead()
a.suite = suite

View file

@ -19,14 +19,14 @@ const (
// The version numbers, making grepping easier
const (
VersionUnknown VersionNumber = math.MaxUint32
VersionDraft29 VersionNumber = 0xff00001d
versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version
Version1 VersionNumber = 0x1
Version2 VersionNumber = 0x6b3343cf
)
// SupportedVersions lists the versions that the server supports
// must be in sorted descending order
var SupportedVersions = []VersionNumber{Version1, Version2, VersionDraft29}
var SupportedVersions = []VersionNumber{Version1, Version2}
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool {
@ -38,7 +38,7 @@ func (vn VersionNumber) String() string {
switch vn {
case VersionUnknown:
return "unknown"
case VersionDraft29:
case versionDraft29:
return "draft-29"
case Version1:
return "v1"

View file

@ -40,7 +40,7 @@ func (e TransportErrorCode) Message() string {
if !e.IsCryptoError() {
return ""
}
return qtls.Alert(e - 0x100).Error()
return qtls.AlertError(e - 0x100).Error()
}
func (e TransportErrorCode) String() string {

View file

@ -0,0 +1,66 @@
//go:build go1.21
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"fmt"
"unsafe"
)
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer
//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13
var defaultCipherSuitesTLS13 []uint16
//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES
var defaultCipherSuitesTLS13NoAES []uint16
var cipherSuitesModified bool
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
// such that it only contains the cipher suite with the chosen id.
// The reset function returned resets them back to the original value.
func SetCipherSuite(id uint16) (reset func()) {
if cipherSuitesModified {
panic("cipher suites modified multiple times without resetting")
}
cipherSuitesModified = true
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
switch id {
case tls.TLS_AES_128_GCM_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
case tls.TLS_CHACHA20_POLY1305_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
case tls.TLS_AES_256_GCM_SHA384:
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
default:
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
}
defaultCipherSuitesTLS13 = []uint16{id}
defaultCipherSuitesTLS13NoAES = []uint16{id}
return func() {
cipherSuitesTLS13 = origCipherSuitesTLS13
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
cipherSuitesModified = false
}
}

View file

@ -0,0 +1,61 @@
//go:build go1.21
package qtls
import (
"crypto/tls"
)
type clientSessionCache struct {
getData func() []byte
setData func([]byte)
wrapped tls.ClientSessionCache
}
var _ tls.ClientSessionCache = &clientSessionCache{}
func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
if cs == nil {
c.wrapped.Put(key, nil)
return
}
ticket, state, err := cs.ResumptionState()
if err != nil || state == nil {
c.wrapped.Put(key, cs)
return
}
state.Extra = append(state.Extra, addExtraPrefix(c.getData()))
newCS, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error. Just save the original state.
c.wrapped.Put(key, cs)
return
}
c.wrapped.Put(key, newCS)
}
func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
cs, ok := c.wrapped.Get(key)
if !ok || cs == nil {
return cs, ok
}
ticket, state, err := cs.ResumptionState()
if err != nil {
// It's not clear why this would error.
// Remove the ticket from the session cache, so we don't run into this error over and over again
c.wrapped.Put(key, nil)
return nil, false
}
// restore QUIC transport parameters and RTT stored in state.Extra
if extra := findExtraData(state.Extra); extra != nil {
c.setData(extra)
}
session, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error.
// Remove the ticket from the session cache, so we don't run into this error over and over again
c.wrapped.Put(key, nil)
return nil, false
}
return session, true
}

View file

@ -1,145 +0,0 @@
//go:build go1.19 && !go1.20
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"fmt"
"net"
"unsafe"
"github.com/quic-go/qtls-go1-19"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains information about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionStateWith0RTT
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// ExtraConfig is the qtls.ExtraConfig
ExtraConfig = qtls.ExtraConfig
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
)
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
)
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, config, extraConfig)
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, config, extraConfig)
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionStateWith0RTT()
}
// ToTLSConnectionState extracts the tls.ConnectionState
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
return cs.ConnectionState
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-19.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
}
}
//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-19.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer
//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13
var defaultCipherSuitesTLS13 []uint16
//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13NoAES
var defaultCipherSuitesTLS13NoAES []uint16
var cipherSuitesModified bool
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
// such that it only contains the cipher suite with the chosen id.
// The reset function returned resets them back to the original value.
func SetCipherSuite(id uint16) (reset func()) {
if cipherSuitesModified {
panic("cipher suites modified multiple times without resetting")
}
cipherSuitesModified = true
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
switch id {
case tls.TLS_AES_128_GCM_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
case tls.TLS_CHACHA20_POLY1305_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
case tls.TLS_AES_256_GCM_SHA384:
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
default:
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
}
defaultCipherSuitesTLS13 = []uint16{id}
defaultCipherSuitesTLS13NoAES = []uint16{id}
return func() {
cipherSuitesTLS13 = origCipherSuitesTLS13
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
cipherSuitesModified = false
}
}

View file

@ -1,101 +1,97 @@
//go:build go1.20
//go:build go1.20 && !go1.21
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"fmt"
"net"
"unsafe"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/qtls-go1-20"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains information about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionStateWith0RTT
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// ExtraConfig is the qtls.ExtraConfig
ExtraConfig = qtls.ExtraConfig
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
QUICConn = qtls.QUICConn
QUICConfig = qtls.QUICConfig
QUICEvent = qtls.QUICEvent
QUICEventKind = qtls.QUICEventKind
QUICEncryptionLevel = qtls.QUICEncryptionLevel
AlertError = qtls.AlertError
)
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
QUICEncryptionLevelInitial = qtls.QUICEncryptionLevelInitial
QUICEncryptionLevelEarly = qtls.QUICEncryptionLevelEarly
QUICEncryptionLevelHandshake = qtls.QUICEncryptionLevelHandshake
QUICEncryptionLevelApplication = qtls.QUICEncryptionLevelApplication
)
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
const (
QUICNoEvent = qtls.QUICNoEvent
QUICSetReadSecret = qtls.QUICSetReadSecret
QUICSetWriteSecret = qtls.QUICSetWriteSecret
QUICWriteData = qtls.QUICWriteData
QUICTransportParameters = qtls.QUICTransportParameters
QUICTransportParametersRequired = qtls.QUICTransportParametersRequired
QUICRejectedEarlyData = qtls.QUICRejectedEarlyData
QUICHandshakeDone = qtls.QUICHandshakeDone
)
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) {
qtls.InitSessionTicketKeys(conf.TLSConfig)
conf.TLSConfig = conf.TLSConfig.Clone()
conf.TLSConfig.MinVersion = tls.VersionTLS13
conf.ExtraConfig = &qtls.ExtraConfig{
Enable0RTT: enable0RTT,
Accept0RTT: accept0RTT,
GetAppDataForSessionTicket: getDataForSessionTicket,
}
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, config, extraConfig)
func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte)) {
conf.ExtraConfig = &qtls.ExtraConfig{
GetAppDataForSessionState: getDataForSessionState,
SetAppDataFromSessionState: setDataFromSessionState,
}
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, config, extraConfig)
func QUICServer(config *QUICConfig) *QUICConn {
return qtls.QUICServer(config)
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionStateWith0RTT()
func QUICClient(config *QUICConfig) *QUICConn {
return qtls.QUICClient(config)
}
// ToTLSConnectionState extracts the tls.ConnectionState
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
return cs.ConnectionState
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) qtls.QUICEncryptionLevel {
switch e {
case protocol.EncryptionInitial:
return qtls.QUICEncryptionLevelInitial
case protocol.EncryptionHandshake:
return qtls.QUICEncryptionLevelHandshake
case protocol.Encryption1RTT:
return qtls.QUICEncryptionLevelApplication
case protocol.Encryption0RTT:
return qtls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-20.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
func FromTLSEncryptionLevel(e qtls.QUICEncryptionLevel) protocol.EncryptionLevel {
switch e {
case qtls.QUICEncryptionLevelInitial:
return protocol.EncryptionInitial
case qtls.QUICEncryptionLevelHandshake:
return protocol.EncryptionHandshake
case qtls.QUICEncryptionLevelApplication:
return protocol.Encryption1RTT
case qtls.QUICEncryptionLevelEarly:
return protocol.Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}

View file

@ -2,4 +2,153 @@
package qtls
var _ int = "The version of quic-go you're using can't be built on Go 1.21 yet. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions."
import (
"bytes"
"crypto/tls"
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
)
type (
QUICConn = tls.QUICConn
QUICConfig = tls.QUICConfig
QUICEvent = tls.QUICEvent
QUICEventKind = tls.QUICEventKind
QUICEncryptionLevel = tls.QUICEncryptionLevel
AlertError = tls.AlertError
)
const (
QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial
QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly
QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake
QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication
)
const (
QUICNoEvent = tls.QUICNoEvent
QUICSetReadSecret = tls.QUICSetReadSecret
QUICSetWriteSecret = tls.QUICSetWriteSecret
QUICWriteData = tls.QUICWriteData
QUICTransportParameters = tls.QUICTransportParameters
QUICTransportParametersRequired = tls.QUICTransportParametersRequired
QUICRejectedEarlyData = tls.QUICRejectedEarlyData
QUICHandshakeDone = tls.QUICHandshakeDone
)
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) {
conf := qconf.TLSConfig
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
qconf.TLSConfig = conf
// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
// Add QUIC transport parameters if this is a 0-RTT packet.
// TODO(#3853): also save the RTT for non-0-RTT tickets
if state.EarlyData {
state.Extra = append(state.Extra, addExtraPrefix(getData()))
}
if origWrapSession != nil {
return origWrapSession(cs, state)
}
b, err := conf.EncryptTicket(cs, state)
return b, err
}
origUnwrapSession := conf.UnwrapSession
// UnwrapSession might be called multiple times, as the client can use multiple session tickets.
// However, using 0-RTT is only possible with the first session ticket.
// crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello.
var unwrapCount int
conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) {
unwrapCount++
var state *tls.SessionState
var err error
if origUnwrapSession != nil {
state, err = origUnwrapSession(identity, connState)
} else {
state, err = conf.DecryptTicket(identity, connState)
}
if err != nil || state == nil {
return nil, err
}
if state.EarlyData {
extra := findExtraData(state.Extra)
if unwrapCount == 1 && extra != nil { // first session ticket
state.EarlyData = accept0RTT(extra)
} else { // subsequent session ticket, can't be used for 0-RTT
state.EarlyData = false
}
}
return state, nil
}
}
func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte)) {
conf := qconf.TLSConfig
if conf.ClientSessionCache != nil {
origCache := conf.ClientSessionCache
conf.ClientSessionCache = &clientSessionCache{
wrapped: origCache,
getData: getData,
setData: setData,
}
}
}
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel {
switch e {
case protocol.EncryptionInitial:
return tls.QUICEncryptionLevelInitial
case protocol.EncryptionHandshake:
return tls.QUICEncryptionLevelHandshake
case protocol.Encryption1RTT:
return tls.QUICEncryptionLevelApplication
case protocol.Encryption0RTT:
return tls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel {
switch e {
case tls.QUICEncryptionLevelInitial:
return protocol.EncryptionInitial
case tls.QUICEncryptionLevelHandshake:
return protocol.EncryptionHandshake
case tls.QUICEncryptionLevelApplication:
return protocol.Encryption1RTT
case tls.QUICEncryptionLevelEarly:
return protocol.Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}
const extraPrefix = "quic-go1"
func addExtraPrefix(b []byte) []byte {
return append([]byte(extraPrefix), b...)
}
func findExtraData(extras [][]byte) []byte {
prefix := []byte(extraPrefix)
for _, extra := range extras {
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
continue
}
return extra[len(prefix):]
}
return nil
}

View file

@ -1,4 +1,4 @@
//go:build !go1.19
//go:build !go1.20
package qtls

View file

@ -108,7 +108,7 @@ func Is0RTTPacket(b []byte) bool {
version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
//nolint:exhaustive // We only need to test QUIC versions that we support.
switch version {
case protocol.Version1, protocol.VersionDraft29:
case protocol.Version1:
return b[0]>>4&0b11 == 0b01
case protocol.Version2:
return b[0]>>4&0b11 == 0b10

42
vendor/github.com/quic-go/quic-go/oss-fuzz.sh generated vendored Normal file
View file

@ -0,0 +1,42 @@
#!/bin/bash
# Install Go manually, since oss-fuzz ships with an outdated Go version.
# See https://github.com/google/oss-fuzz/pull/10643.
export CXX="${CXX} -lresolv" # required by Go 1.20
wget https://go.dev/dl/go1.20.5.linux-amd64.tar.gz \
&& mkdir temp-go \
&& rm -rf /root/.go/* \
&& tar -C temp-go/ -xzf go1.20.5.linux-amd64.tar.gz \
&& mv temp-go/go/* /root/.go/ \
&& rm -rf temp-go go1.20.5.linux-amd64.tar.gz
(
# fuzz qpack
compile_go_fuzzer github.com/quic-go/qpack/fuzzing Fuzz qpack_fuzzer
)
(
# fuzz quic-go
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/frames Fuzz frame_fuzzer
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/header Fuzz header_fuzzer
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/transportparameters Fuzz transportparameter_fuzzer
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/tokens Fuzz token_fuzzer
compile_go_fuzzer github.com/quic-go/quic-go/fuzzing/handshake Fuzz handshake_fuzzer
if [ $SANITIZER == "coverage" ]; then
# no need for corpora if coverage
exit 0
fi
# generate seed corpora
cd $GOPATH/src/github.com/quic-go/quic-go/
go generate -x ./fuzzing/...
zip --quiet -r $OUT/header_fuzzer_seed_corpus.zip fuzzing/header/corpus
zip --quiet -r $OUT/frame_fuzzer_seed_corpus.zip fuzzing/frames/corpus
zip --quiet -r $OUT/transportparameter_fuzzer_seed_corpus.zip fuzzing/transportparameters/corpus
zip --quiet -r $OUT/handshake_fuzzer_seed_corpus.zip fuzzing/handshake/corpus
)
# for debugging
ls -al $OUT

View file

@ -82,7 +82,7 @@ func (h *sendQueue) Run() error {
// 1. Checking for "datagram too large" message from the kernel, as such,
// 2. Path MTU discovery,and
// 3. Eventual detection of loss PingFrame.
if !isMsgSizeErr(err) {
if !isSendMsgSizeErr(err) {
return err
}
}

View file

@ -30,7 +30,7 @@ type sendStream struct {
retransmissionQueue []*wire.StreamFrame
ctx context.Context
ctxCancel context.CancelFunc
ctxCancel context.CancelCauseFunc
streamID protocol.StreamID
sender streamSender
@ -71,7 +71,7 @@ func newSendStream(
writeChan: make(chan struct{}, 1),
writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.ctx, s.ctxCancel = context.WithCancelCause(context.Background())
return s
}
@ -366,7 +366,7 @@ func (s *sendStream) Close() error {
s.mutex.Unlock()
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.ctxCancel()
s.ctxCancel(nil)
s.finishedWriting = true
s.mutex.Unlock()
@ -385,8 +385,8 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
s.mutex.Unlock()
return
}
s.ctxCancel()
s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
s.ctxCancel(s.cancelWriteErr)
s.numOutstandingFrames = 0
s.retransmissionQueue = nil
newlyCompleted := s.isNewlyCompleted()
@ -435,7 +435,7 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.ctxCancel()
s.ctxCancel(err)
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalWrite()

View file

@ -2,7 +2,11 @@ package quic
import (
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"syscall"
"time"
@ -23,27 +27,28 @@ type OOBCapablePacketConn interface {
var _ OOBCapablePacketConn = &net.UDPConn{}
// OptimizeConn takes a net.PacketConn and attempts to enable various optimizations that will improve QUIC performance:
// 1. It enables the Don't Fragment (DF) bit on the IP header.
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
// 2. It enables reading of the ECN bits from the IP header.
// This allows the remote node to speed up its loss detection and recovery.
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
//
// In order for this to work, the connection needs to implement the OOBCapablePacketConn interface (as a *net.UDPConn does).
//
// It's only necessary to call this function explicitly if the application calls WriteTo
// after passing the connection to the Transport.
func OptimizeConn(c net.PacketConn) (net.PacketConn, error) {
return wrapConn(c)
}
func wrapConn(pc net.PacketConn) (rawConn, error) {
if err := setReceiveBuffer(pc); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err)
})
}
}
if err := setSendBuffer(pc); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err)
})
}
}
func wrapConn(pc net.PacketConn) (interface {
net.PacketConn
rawConn
}, error,
) {
conn, ok := pc.(interface {
SyscallConn() (syscall.RawConn, error)
})

View file

@ -11,7 +11,7 @@ import (
)
//go:generate sh -c "echo '// Code generated by go generate. DO NOT EDIT.\n// Source: sys_conn_buffers.go\n' > sys_conn_buffers_write.go && sed -e 's/SetReadBuffer/SetWriteBuffer/g' -e 's/setReceiveBuffer/setSendBuffer/g' -e 's/inspectReadBuffer/inspectWriteBuffer/g' -e 's/protocol\\.DesiredReceiveBufferSize/protocol\\.DesiredSendBufferSize/g' -e 's/forceSetReceiveBuffer/forceSetSendBuffer/g' -e 's/receive buffer/send buffer/g' sys_conn_buffers.go | sed '/^\\/\\/go:generate/d' >> sys_conn_buffers_write.go"
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
func setReceiveBuffer(c net.PacketConn) error {
conn, ok := c.(interface{ SetReadBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
@ -40,7 +40,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if size >= protocol.DesiredReceiveBufferSize {
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
utils.DefaultLogger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
return nil
}
// Ignore the error. We check if we succeeded by querying the buffer size afterward.
@ -63,6 +63,6 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
if newSize < protocol.DesiredReceiveBufferSize {
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
utils.DefaultLogger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
return nil
}

View file

@ -13,7 +13,7 @@ import (
"github.com/quic-go/quic-go/internal/utils"
)
func setSendBuffer(c net.PacketConn, logger utils.Logger) error {
func setSendBuffer(c net.PacketConn) error {
conn, ok := c.(interface{ SetWriteBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of send buffer size. Not a *net.UDPConn?")
@ -42,7 +42,7 @@ func setSendBuffer(c net.PacketConn, logger utils.Logger) error {
return fmt.Errorf("failed to determine send buffer size: %w", err)
}
if size >= protocol.DesiredSendBufferSize {
logger.Debugf("Conn has send buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024)
utils.DefaultLogger.Debugf("Conn has send buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024)
return nil
}
// Ignore the error. We check if we succeeded by querying the buffer size afterward.
@ -65,6 +65,6 @@ func setSendBuffer(c net.PacketConn, logger utils.Logger) error {
if newSize < protocol.DesiredSendBufferSize {
return fmt.Errorf("failed to sufficiently increase send buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased send buffer size to %d kiB", newSize/1024)
utils.DefaultLogger.Debugf("Increased send buffer size to %d kiB", newSize/1024)
return nil
}

View file

@ -1,4 +1,4 @@
//go:build !linux && !windows
//go:build !linux && !windows && !darwin
package quic
@ -11,7 +11,12 @@ func setDF(syscall.RawConn) (bool, error) {
return false, nil
}
func isMsgSizeErr(err error) bool {
func isSendMsgSizeErr(err error) bool {
// to be implemented for more specific platforms
return false
}
func isRecvMsgSizeErr(err error) bool {
// to be implemented for more specific platforms
return false
}

View file

@ -0,0 +1,74 @@
//go:build darwin
package quic
import (
"errors"
"strconv"
"strings"
"syscall"
"golang.org/x/sys/unix"
"github.com/quic-go/quic-go/internal/utils"
)
func setDF(rawConn syscall.RawConn) (bool, error) {
// Setting DF bit is only supported from macOS11
// https://github.com/chromium/chromium/blob/117.0.5881.2/net/socket/udp_socket_posix.cc#L555
if supportsDF, err := isAtLeastMacOS11(); !supportsDF || err != nil {
return false, err
}
// Enabling IP_DONTFRAG will force the kernel to return "sendto: message too long"
// and the datagram will not be fragmented
var errDFIPv4, errDFIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
}); err != nil {
return false, err
}
switch {
case errDFIPv4 == nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.")
case errDFIPv4 == nil && errDFIPv6 != nil:
utils.DefaultLogger.Debugf("Setting DF for IPv4.")
case errDFIPv4 != nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv6.")
// On macOS, the syscall for setting DF bit for IPv4 fails on dual-stack listeners.
// Treat the connection as not having DF enabled, even though the DF bit will be set
// when used for IPv6.
// See https://github.com/quic-go/quic-go/issues/3793 for details.
return false, nil
case errDFIPv4 != nil && errDFIPv6 != nil:
return false, errors.New("setting DF failed for both IPv4 and IPv6")
}
return true, nil
}
func isSendMsgSizeErr(err error) bool {
return errors.Is(err, unix.EMSGSIZE)
}
func isRecvMsgSizeErr(error) bool { return false }
func isAtLeastMacOS11() (bool, error) {
uname := &unix.Utsname{}
err := unix.Uname(uname)
if err != nil {
return false, err
}
release := string(uname.Release[:])
if idx := strings.Index(release, "."); idx != -1 {
version, err := strconv.Atoi(release[:idx])
if err != nil {
return false, err
}
// Darwin version 20 is macOS version 11
// https://en.wikipedia.org/wiki/Darwin_(operating_system)#Darwin_20_onwards
return version >= 20, nil
}
return false, nil
}

View file

@ -15,11 +15,6 @@ import (
"github.com/quic-go/quic-go/internal/utils"
)
// UDP_SEGMENT controls GSO (Generic Segmentation Offload)
//
//nolint:stylecheck
const UDP_SEGMENT = 103
func setDF(rawConn syscall.RawConn) (bool, error) {
// Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long"
// and the datagram will not be fragmented
@ -51,7 +46,7 @@ func maybeSetGSO(rawConn syscall.RawConn) bool {
var setErr error
if err := rawConn.Control(func(fd uintptr) {
setErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_UDP, UDP_SEGMENT, 1)
setErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_UDP, unix.UDP_SEGMENT, 1)
}); err != nil {
setErr = err
}
@ -62,18 +57,20 @@ func maybeSetGSO(rawConn syscall.RawConn) bool {
return true
}
func isMsgSizeErr(err error) bool {
func isSendMsgSizeErr(err error) bool {
// https://man7.org/linux/man-pages/man7/udp.7.html
return errors.Is(err, unix.EMSGSIZE)
}
func isRecvMsgSizeErr(err error) bool { return false }
func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte {
startLen := len(b)
const dataLen = 2 // payload is a uint16
b = append(b, make([]byte, unix.CmsgSpace(dataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_UDP
h.Type = UDP_SEGMENT
h.Type = unix.UDP_SEGMENT
h.SetLen(unix.CmsgLen(dataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.

View file

@ -43,7 +43,12 @@ func setDF(rawConn syscall.RawConn) (bool, error) {
return true, nil
}
func isMsgSizeErr(err error) bool {
func isSendMsgSizeErr(err error) bool {
// https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
return errors.Is(err, windows.WSAEMSGSIZE)
}
func isRecvMsgSizeErr(err error) bool {
// https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
return errors.Is(err, windows.WSAEMSGSIZE)
}

View file

@ -2,20 +2,30 @@
package quic
import "golang.org/x/sys/unix"
import (
"encoding/binary"
"net/netip"
const msgTypeIPTOS = unix.IP_RECVTOS
const (
ipv4RECVPKTINFO = unix.IP_RECVPKTINFO
ipv6RECVPKTINFO = 0x3d
"golang.org/x/sys/unix"
)
const (
msgTypeIPv4PKTINFO = unix.IP_PKTINFO
msgTypeIPv6PKTINFO = 0x2e
msgTypeIPTOS = unix.IP_RECVTOS
ipv4PKTINFO = unix.IP_RECVPKTINFO
)
// ReadBatch only returns a single packet on OSX,
// see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch.
const batchSize = 1
func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) {
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */
// };
if len(body) != 12 {
return netip.Addr{}, 0, false
}
return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true
}

View file

@ -2,20 +2,25 @@
package quic
import "golang.org/x/sys/unix"
import (
"net/netip"
"golang.org/x/sys/unix"
)
const (
msgTypeIPTOS = unix.IP_RECVTOS
)
const (
ipv4RECVPKTINFO = 0x7
ipv6RECVPKTINFO = 0x24
)
const (
msgTypeIPv4PKTINFO = 0x7
msgTypeIPv6PKTINFO = 0x2e
ipv4PKTINFO = 0x7
)
const batchSize = 8
func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) {
// struct in_pktinfo {
// struct in_addr ipi_addr; /* Header Destination address */
// };
if len(body) != 4 {
return netip.Addr{}, 0, false
}
return netip.AddrFrom4(*(*[4]byte)(body)), 0, true
}

View file

@ -3,21 +3,16 @@
package quic
import (
"encoding/binary"
"net/netip"
"syscall"
"golang.org/x/sys/unix"
)
const msgTypeIPTOS = unix.IP_TOS
const (
ipv4RECVPKTINFO = unix.IP_PKTINFO
ipv6RECVPKTINFO = unix.IPV6_RECVPKTINFO
)
const (
msgTypeIPv4PKTINFO = unix.IP_PKTINFO
msgTypeIPv6PKTINFO = unix.IPV6_PKTINFO
msgTypeIPTOS = unix.IP_TOS
ipv4PKTINFO = unix.IP_PKTINFO
)
const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed)
@ -41,3 +36,15 @@ func forceSetSendBuffer(c syscall.RawConn, bytes int) error {
}
return serr
}
func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) {
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */
// };
if len(body) != 12 {
return netip.Addr{}, 0, false
}
return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true
}

View file

@ -6,8 +6,10 @@ import (
"encoding/binary"
"errors"
"fmt"
"log"
"net"
"net/netip"
"sync"
"syscall"
"time"
@ -87,8 +89,8 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
if needsPacketInfo {
errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1)
errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1)
errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4PKTINFO, 1)
errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
}
}); err != nil {
return nil, err
@ -149,6 +151,8 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
return oobConn, nil
}
var invalidCmsgOnceV4, invalidCmsgOnceV6 sync.Once
func (c *oobConn) ReadPacket() (receivedPacket, error) {
if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
c.messages = c.messages[:batchSize]
@ -188,38 +192,36 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
switch hdr.Type {
case msgTypeIPTOS:
p.ecn = protocol.ECN(body[0] & ecnMask)
case msgTypeIPv4PKTINFO:
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination
// address */
// };
var ip [4]byte
if len(body) == 12 {
copy(ip[:], body[8:12])
p.info.ifIndex = binary.LittleEndian.Uint32(body)
} else if len(body) == 4 {
// FreeBSD
copy(ip[:], body)
case ipv4PKTINFO:
ip, ifIndex, ok := parseIPv4PktInfo(body)
if ok {
p.info.addr = ip
p.info.ifIndex = ifIndex
} else {
invalidCmsgOnceV4.Do(func() {
log.Printf("Received invalid IPv4 packet info control message: %+x. "+
"This should never occur, please open a new issue and include details about the architecture.", body)
})
}
p.info.addr = netip.AddrFrom4(ip)
}
}
if hdr.Level == unix.IPPROTO_IPV6 {
switch hdr.Type {
case unix.IPV6_TCLASS:
p.ecn = protocol.ECN(body[0] & ecnMask)
case msgTypeIPv6PKTINFO:
case unix.IPV6_PKTINFO:
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */
// };
if len(body) == 20 {
var ip [16]byte
copy(ip[:], body[:16])
p.info.addr = netip.AddrFrom16(ip)
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16]))
p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
} else {
invalidCmsgOnceV6.Do(func() {
log.Printf("Received invalid IPv6 packet info control message: %+x. "+
"This should never occur, please open a new issue and include details about the architecture.", body)
})
}
}
}
@ -228,13 +230,6 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
return p, nil
}
// WriteTo (re)implements the net.PacketConn method.
// This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection.
// With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set.
func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return c.WritePacket(p, uint16(len(p)), addr, nil)
}
// WritePacket writes a new packet.
// If the connection supports GSO (and we activated GSO support before),
// it appends the UDP_SEGMENT size message to oob.

View file

@ -5,11 +5,7 @@ import (
"crypto/rand"
"crypto/tls"
"errors"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
@ -30,9 +26,16 @@ type Transport struct {
// A single net.PacketConn can only be handled by one Transport.
// Bad things will happen if passed to multiple Transports.
//
// If not done by the user, the connection is passed through OptimizeConn to enable a number of optimizations.
// After passing the connection to the Transport, it's invalid to call ReadFrom on the connection.
// Calling WriteTo is only valid on the connection returned by OptimizeConn.
// A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface,
// as a *net.UDPConn does.
// 1. It enables the Don't Fragment (DF) bit on the IP header.
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
// 2. It enables reading of the ECN bits from the IP header.
// This allows the remote node to speed up its loss detection and recovery.
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
//
// After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection.
Conn net.PacketConn
// The length of the connection ID in bytes.
@ -103,7 +106,7 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(true); err != nil {
if err := t.init(false); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
@ -132,7 +135,7 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(true); err != nil {
if err := t.init(false); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
@ -149,13 +152,15 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
return nil, err
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
}
@ -165,20 +170,20 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
return nil, err
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
}
func (t *Transport) init(isServer bool) error {
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.initOnce.Do(func() {
getMultiplexer().AddConn(t.Conn)
var conn rawConn
if c, ok := t.Conn.(rawConn); ok {
conn = c
@ -205,19 +210,28 @@ func (t *Transport) init(isServer bool) error {
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
} else {
connIDLen := t.ConnectionIDLength
if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs {
connIDLen = protocol.DefaultConnectionIDLength
}
t.connIDLen = connIDLen
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
}
getMultiplexer().AddConn(t.Conn)
go t.listen(conn)
go t.runSendQueue()
})
return t.initErr
}
// WriteTo sends a packet on the underlying connection.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {
return 0, err
}
return t.conn.WritePacket(b, uint16(len(b)), addr, nil)
}
func (t *Transport) enqueueClosePacket(p closePacket) {
select {
case t.closeQueue <- p:
@ -299,27 +313,6 @@ func (t *Transport) listen(conn rawConn) {
defer close(t.listening)
defer getMultiplexer().RemoveConn(t.Conn)
if err := setReceiveBuffer(t.Conn, t.logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err)
})
}
}
if err := setSendBuffer(t.Conn, t.logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err)
})
}
}
for {
p, err := conn.ReadPacket()
//nolint:staticcheck // SA1019 ignore this!
@ -337,6 +330,10 @@ func (t *Transport) listen(conn rawConn) {
continue
}
if err != nil {
// Windows returns an error when receiving a UDP datagram that doesn't fit into the provided buffer.
if isRecvMsgSizeErr(err) {
continue
}
t.close(err)
return
}