Update quic-go

This commit is contained in:
Frank Denis 2023-02-25 23:45:38 +01:00
parent 47e6a56b16
commit 15c87a68a1
55 changed files with 1448 additions and 12574 deletions

View file

@ -12,7 +12,7 @@ In addition to the RFCs listed above, it currently implements the [IETF QUIC dra
## Guides
*We currently support Go 1.18.x and Go 1.19.x.*
*We currently support Go 1.19.x and Go 1.20.x*
Running tests:

View file

@ -1659,6 +1659,7 @@ func (s *connection) handleTransportParameters(params *wire.TransportParameters)
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
})
return
}
s.peerParams = params
// On the client side we have to wait for handshake completion.
@ -2026,6 +2027,21 @@ func (s *connection) logShortHeaderPacket(
func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
if s.logger.Debug() {
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
// during which we might call PackCoalescedPacket but just pack a short header packet.
if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil {
s.logShortHeaderPacket(
packet.shortHdrPacket.DestConnID,
packet.shortHdrPacket.Ack,
packet.shortHdrPacket.Frames,
packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase,
packet.shortHdrPacket.Length,
false,
)
return
}
if len(packet.longHdrPackets) > 1 {
s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID)
} else {

View file

@ -1,6 +1,8 @@
package quic
import (
"sync"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
@ -9,7 +11,10 @@ import (
type datagramQueue struct {
sendQueue chan *wire.DatagramFrame
nextFrame *wire.DatagramFrame
rcvQueue chan []byte
rcvMx sync.Mutex
rcvQueue [][]byte
rcvd chan struct{} // used to notify Receive that a new datagram was received
closeErr error
closed chan struct{}
@ -25,7 +30,7 @@ func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
return &datagramQueue{
hasData: hasData,
sendQueue: make(chan *wire.DatagramFrame, 1),
rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen),
rcvd: make(chan struct{}, 1),
dequeued: make(chan struct{}),
closed: make(chan struct{}),
logger: logger,
@ -76,20 +81,39 @@ func (h *datagramQueue) Pop() {
func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
data := make([]byte, len(f.Data))
copy(data, f.Data)
select {
case h.rcvQueue <- data:
default:
var queued bool
h.rcvMx.Lock()
if len(h.rcvQueue) < protocol.DatagramRcvQueueLen {
h.rcvQueue = append(h.rcvQueue, data)
queued = true
select {
case h.rcvd <- struct{}{}:
default:
}
}
h.rcvMx.Unlock()
if !queued && h.logger.Debug() {
h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data))
}
}
// Receive gets a received DATAGRAM frame.
func (h *datagramQueue) Receive() ([]byte, error) {
select {
case data := <-h.rcvQueue:
return data, nil
case <-h.closed:
return nil, h.closeErr
for {
h.rcvMx.Lock()
if len(h.rcvQueue) > 0 {
data := h.rcvQueue[0]
h.rcvQueue = h.rcvQueue[1:]
h.rcvMx.Unlock()
return data, nil
}
h.rcvMx.Unlock()
select {
case <-h.rcvd:
continue
case <-h.closed:
return nil, h.closeErr
}
}
}

View file

@ -9,6 +9,7 @@ import (
"net/http"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
@ -63,7 +64,7 @@ type client struct {
decoder *qpack.Decoder
hostname string
conn quic.EarlyConnection
conn atomic.Pointer[quic.EarlyConnection]
logger utils.Logger
}
@ -108,33 +109,35 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con
func (c *client) dial(ctx context.Context) error {
var err error
var conn quic.EarlyConnection
if c.dialer != nil {
c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
} else {
c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
c.conn.Store(&conn)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(); err != nil {
if err := c.setupConn(conn); err != nil {
c.logger.Debugf("Setting up connection failed: %s", err)
c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
}
}()
if c.opts.StreamHijacker != nil {
go c.handleBidirectionalStreams()
go c.handleBidirectionalStreams(conn)
}
go c.handleUnidirectionalStreams()
go c.handleUnidirectionalStreams(conn)
return nil
}
func (c *client) setupConn() error {
func (c *client) setupConn(conn quic.EarlyConnection) error {
// open the control stream
str, err := c.conn.OpenUniStream()
str, err := conn.OpenUniStream()
if err != nil {
return err
}
@ -146,16 +149,16 @@ func (c *client) setupConn() error {
return err
}
func (c *client) handleBidirectionalStreams() {
func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
for {
str, err := c.conn.AcceptStream(context.Background())
str, err := conn.AcceptStream(context.Background())
if err != nil {
c.logger.Debugf("accepting bidirectional stream failed: %s", err)
return
}
go func(str quic.Stream) {
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str, e)
return c.opts.StreamHijacker(ft, conn, str, e)
})
if err == errHijacked {
return
@ -163,14 +166,14 @@ func (c *client) handleBidirectionalStreams() {
if err != nil {
c.logger.Debugf("error handling stream: %s", err)
}
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}(str)
}
}
func (c *client) handleUnidirectionalStreams() {
func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
for {
str, err := c.conn.AcceptUniStream(context.Background())
str, err := conn.AcceptUniStream(context.Background())
if err != nil {
c.logger.Debugf("accepting unidirectional stream failed: %s", err)
return
@ -179,7 +182,7 @@ func (c *client) handleUnidirectionalStreams() {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
return
}
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
@ -194,10 +197,10 @@ func (c *client) handleUnidirectionalStreams() {
return
case streamTypePushStream:
// We never increased the Push ID, so we don't expect any push streams.
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
@ -205,12 +208,12 @@ func (c *client) handleUnidirectionalStreams() {
}
f, err := parseNextFrame(str, nil)
if err != nil {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
return
}
sf, ok := f.(*settingsFrame)
if !ok {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
return
}
if !sf.Datagram {
@ -219,18 +222,19 @@ func (c *client) handleUnidirectionalStreams() {
// If datagram support was enabled on our side as well as on the server side,
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
}
}(str)
}
}
func (c *client) Close() error {
if c.conn == nil {
conn := c.conn.Load()
if conn == nil {
return nil
}
return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
return (*conn).CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
}
func (c *client) maxHeaderBytes() uint64 {
@ -249,24 +253,26 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
c.dialOnce.Do(func() {
c.handshakeErr = c.dial(req.Context())
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
}
// At this point, c.conn is guaranteed to be set.
conn := *c.conn.Load()
// Immediately send out this request, if this is a 0-RTT request.
if req.Method == MethodGet0RTT {
req.Method = http.MethodGet
} else {
// wait for the handshake to complete
select {
case <-c.conn.HandshakeComplete().Done():
case <-conn.HandshakeComplete().Done():
case <-req.Context().Done():
return nil, req.Context().Err()
}
}
str, err := c.conn.OpenStreamSync(req.Context())
str, err := conn.OpenStreamSync(req.Context())
if err != nil {
return nil, err
}
@ -290,7 +296,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
if opt.DontCloseRequestStream {
doneChan = nil
}
rsp, rerr := c.doRequest(req, str, opt, doneChan)
rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
if rerr.err != nil { // if any error occurred
close(reqDone)
<-done
@ -302,7 +308,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
if rerr.err != nil {
reason = rerr.err.Error()
}
c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
}
return nil, rerr.err
}
@ -340,7 +346,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
return nil
}
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
var requestGzip bool
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
requestGzip = true
@ -353,7 +359,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
str.Close()
}
hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
if req.Body != nil {
// send the request body asynchronously
go func() {
@ -387,7 +393,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
return nil, newConnError(errorGeneralProtocolError, err)
}
connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS)
connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS)
res := &http.Response{
Proto: "HTTP/3.0",
ProtoMajor: 3,
@ -408,7 +414,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
res.Header.Add(hf.Name, hf.Value)
}
}
respBody := newResponseBody(hstr, c.conn, reqDone)
respBody := newResponseBody(hstr, conn, reqDone)
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
@ -438,11 +444,12 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
}
func (c *client) HandshakeComplete() bool {
if c.conn == nil {
conn := c.conn.Load()
if conn == nil {
return false
}
select {
case <-c.conn.HandshakeComplete().Done():
case <-(*conn).HandshakeComplete().Done():
return true
default:
return false

View file

@ -80,10 +80,26 @@ func (w *responseWriter) WriteHeader(status int) {
}
func (w *responseWriter) Write(p []byte) (int, error) {
bodyAllowed := bodyAllowedForStatus(w.status)
if !w.headerWritten {
// If body is not allowed, we don't need to (and we can't) sniff the content type.
if bodyAllowed {
// If no content type, apply sniffing algorithm to body.
// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing.
_, haveType := w.header["Content-Type"]
// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
// we shouldn't sniff the body.
hasTE := w.header.Get("Transfer-Encoding") != ""
hasCE := w.header.Get("Content-Encoding") != ""
if !hasCE && !haveType && !hasTE && len(p) > 0 {
w.header.Set("Content-Type", http.DetectContentType(p))
}
}
w.WriteHeader(http.StatusOK)
bodyAllowed = true
}
if !bodyAllowedForStatus(w.status) {
if !bodyAllowed {
return 0, http.ErrBodyNotAllowed
}
df := &dataFrame{Length: uint64(len(p))}

View file

@ -263,7 +263,7 @@ func newCryptoSetup(
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan struct{}),
zeroRTTParametersChan: zeroRTTParametersChan,
messageChan: make(chan []byte, 100),
messageChan: make(chan []byte, 1),
isReadingHandshakeMessage: make(chan struct{}),
closeChan: make(chan struct{}),
version: version,
@ -368,8 +368,15 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev
h.onError(alertUnexpectedMessage, err.Error())
return false
}
h.messageChan <- data
if encLevel != protocol.Encryption1RTT {
select {
case h.messageChan <- data:
case <-h.handshakeDone: // handshake errored, nobody is going to consume this message
return false
}
}
if encLevel == protocol.Encryption1RTT {
h.messageChan <- data
h.handlePostHandshakeMessage()
return false
}

View file

@ -24,7 +24,7 @@ var _ tlsExtensionHandler = &extensionHandler{}
// newExtensionHandler creates a new extension handler
func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler {
et := uint16(quicTLSExtensionType)
if v != protocol.Version1 {
if v == protocol.VersionDraft29 {
et = quicTLSExtensionTypeOldDrafts
}
return &extensionHandler{

View file

@ -1,99 +0,0 @@
//go:build go1.18 && !go1.19
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"net"
"unsafe"
"github.com/quic-go/qtls-go1-18"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains inforamtion about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionStateWith0RTT
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// ExtraConfig is the qtls.ExtraConfig
ExtraConfig = qtls.ExtraConfig
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
)
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
)
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, config, extraConfig)
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, config, extraConfig)
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionStateWith0RTT()
}
// ToTLSConnectionState extracts the tls.ConnectionState
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
return cs.ConnectionState
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-18.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
}
}

View file

@ -1,4 +1,4 @@
//go:build !go1.18
//go:build !go1.19
package qtls

View file

@ -1,22 +0,0 @@
package utils
import "sync/atomic"
// An AtomicBool is an atomic bool
type AtomicBool struct {
v int32
}
// Set sets the value
func (a *AtomicBool) Set(value bool) {
var n int32
if value {
n = 1
}
atomic.StoreInt32(&a.v, n)
}
// Get gets the value
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.v) != 0
}

View file

@ -71,6 +71,7 @@ func Read(r io.ByteReader) (uint64, error) {
}
// Write writes i in the QUIC varint format to w.
// Deprecated: use Append instead.
func Write(w Writer, i uint64) {
if i <= maxVarInt1 {
w.WriteByte(uint8(i))
@ -88,6 +89,7 @@ func Write(w Writer, i uint64) {
}
}
// Append appends i in the QUIC varint format.
func Append(b []byte, i uint64) []byte {
if i <= maxVarInt1 {
return append(b, uint8(i))

View file

@ -34,18 +34,14 @@ type receiveStream struct {
currentFrame []byte
currentFrameDone func()
currentFrameIsLast bool // is the currentFrame the last frame on this stream
readPosInFrame int
currentFrameIsLast bool // is the currentFrame the last frame on this stream
finRead bool // set once we read a frame with a Fin
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr *StreamError
closedForShutdown bool // set when CloseForShutdown() is called
finRead bool // set once we read a frame with a Fin
canceledRead bool // set when CancelRead() is called
resetRemotely bool // set when handleResetStreamFrame() is called
readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
deadline time.Time
@ -100,13 +96,13 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
if s.finRead {
return false, 0, io.EOF
}
if s.canceledRead {
if s.cancelReadErr != nil {
return false, 0, s.cancelReadErr
}
if s.resetRemotely {
if s.resetRemotelyErr != nil {
return false, 0, s.resetRemotelyErr
}
if s.closedForShutdown {
if s.closeForShutdownErr != nil {
return false, 0, s.closeForShutdownErr
}
@ -122,13 +118,13 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
for {
// Stop waiting on errors
if s.closedForShutdown {
if s.closeForShutdownErr != nil {
return false, bytesRead, s.closeForShutdownErr
}
if s.canceledRead {
if s.cancelReadErr != nil {
return false, bytesRead, s.cancelReadErr
}
if s.resetRemotely {
if s.resetRemotelyErr != nil {
return false, bytesRead, s.resetRemotelyErr
}
@ -175,8 +171,9 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
s.readPosInFrame += m
bytesRead += m
// when a RESET_STREAM was received, the was already informed about the final byteOffset for this stream
if !s.resetRemotely {
// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if s.resetRemotelyErr == nil {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}
@ -211,10 +208,9 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
}
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ {
if s.finRead || s.canceledRead || s.resetRemotely {
if s.finRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil {
return false
}
s.canceledRead = true
s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
@ -247,7 +243,7 @@ func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /*
newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount
s.finalOffset = maxOffset
}
if s.canceledRead {
if s.cancelReadErr != nil {
return newlyRcvdFinalOffset, nil
}
if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
@ -270,7 +266,7 @@ func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) err
}
func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) {
if s.closedForShutdown {
if s.closeForShutdownErr != nil {
return false, nil
}
if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil {
@ -280,10 +276,9 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
s.finalOffset = frame.FinalSize
// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.resetRemotely {
if s.resetRemotelyErr != nil {
return false, nil
}
s.resetRemotely = true
s.resetRemotelyErr = &StreamError{
StreamID: s.streamID,
ErrorCode: frame.ErrorCode,
@ -310,7 +305,6 @@ func (s *receiveStream) SetReadDeadline(t time.Time) error {
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET.
func (s *receiveStream) closeForShutdown(err error) {
s.mutex.Lock()
s.closedForShutdown = true
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalRead()

View file

@ -40,11 +40,9 @@ type sendStream struct {
cancelWriteErr error
closeForShutdownErr error
closedForShutdown bool // set when CloseForShutdown() is called
finishedWriting bool // set once Close() is called
canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received
finSent bool // set when a STREAM_FRAME with FIN bit has been sent
completed bool // set when this stream has been reported to the streamSender as completed
finishedWriting bool // set once Close() is called
finSent bool // set when a STREAM_FRAME with FIN bit has been sent
completed bool // set when this stream has been reported to the streamSender as completed
dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
nextFrame *wire.StreamFrame
@ -94,7 +92,7 @@ func (s *sendStream) Write(p []byte) (int, error) {
if s.finishedWriting {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if s.canceledWrite {
if s.cancelWriteErr != nil {
return 0, s.cancelWriteErr
}
if s.closeForShutdownErr != nil {
@ -153,7 +151,7 @@ func (s *sendStream) Write(p []byte) (int, error) {
}
deadlineTimer.Reset(deadline)
}
if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown {
if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
break
}
}
@ -219,7 +217,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
}
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) {
if s.canceledWrite || s.closeForShutdownErr != nil {
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
return nil, false
}
@ -354,7 +352,7 @@ func (s *sendStream) frameAcked(f wire.Frame) {
f.(*wire.StreamFrame).PutBack()
s.mutex.Lock()
if s.canceledWrite {
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
@ -371,7 +369,7 @@ func (s *sendStream) frameAcked(f wire.Frame) {
}
func (s *sendStream) isNewlyCompleted() bool {
completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
if completed && !s.completed {
s.completed = true
return true
@ -383,7 +381,7 @@ func (s *sendStream) queueRetransmission(f wire.Frame) {
sf := f.(*wire.StreamFrame)
sf.DataLenPresent = true
s.mutex.Lock()
if s.canceledWrite {
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
@ -399,11 +397,11 @@ func (s *sendStream) queueRetransmission(f wire.Frame) {
func (s *sendStream) Close() error {
s.mutex.Lock()
if s.closedForShutdown {
if s.closeForShutdownErr != nil {
s.mutex.Unlock()
return nil
}
if s.canceledWrite {
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
@ -422,12 +420,11 @@ func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
// must be called after locking the mutex
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
s.mutex.Lock()
if s.canceledWrite {
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
s.ctxCancel()
s.canceledWrite = true
s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
s.numOutstandingFrames = 0
s.retransmissionQueue = nil
@ -478,7 +475,6 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.ctxCancel()
s.closedForShutdown = true
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalWrite()

View file

@ -341,7 +341,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
return false
}
if !s.config.DisableVersionNegotiationPackets {
go s.sendVersionNegotiationPacket(p.remoteAddr, src, dest, p.info.OOB())
go s.sendVersionNegotiationPacket(p.remoteAddr, src, dest, p.info.OOB(), v)
}
return false
}
@ -669,8 +669,8 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
return err
}
func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest protocol.ArbitraryLenConnectionID, oob []byte) {
s.logger.Debugf("Client offered version %s, sending Version Negotiation")
func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest protocol.ArbitraryLenConnectionID, oob []byte, v protocol.VersionNumber) {
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
if s.config.Tracer != nil {