pass frame / stream type parsing errors to the hijacker callbacks

When a stream is reset, we might not have received the frame / stream
type yet. The callback might be able to identify if it was a stream
intended for that application by analyzing the stream reset error.
This commit is contained in:
Marten Seemann 2022-05-23 21:56:47 +02:00
parent 5cb2e8265c
commit 96c0daceca
7 changed files with 181 additions and 36 deletions

View file

@ -43,8 +43,8 @@ type roundTripperOpts struct {
EnableDatagram bool
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
}
// client is a HTTP3 client doing requests
@ -151,9 +151,8 @@ func (c *client) handleBidirectionalStreams() {
return
}
go func(str quic.Stream) {
for {
_, err := parseNextFrame(str, func(ft FrameType) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str)
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str, e)
})
if err == errHijacked {
return
@ -162,7 +161,6 @@ func (c *client) handleBidirectionalStreams() {
c.logger.Debugf("error handling stream: %s", err)
}
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}
}(str)
}
}
@ -178,6 +176,9 @@ func (c *client) handleUnidirectionalStreams() {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
return
}
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return
}
@ -193,7 +194,7 @@ func (c *client) handleUnidirectionalStreams() {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))

View file

@ -221,7 +221,8 @@ var _ = Describe("Client", func() {
It("hijacks a bidirectional stream of unknown frame type", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return true, nil
}
@ -243,7 +244,8 @@ var _ = Describe("Client", func() {
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, nil
}
@ -265,7 +267,8 @@ var _ = Describe("Client", func() {
It("closes the connection when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, errors.New("error in hijacker")
}
@ -284,6 +287,31 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
It("handles errors that occur when reading the frame type", func() {
testErr := errors.New("test error")
unknownStr := mockquic.NewMockStream(mockCtrl)
done := make(chan struct{})
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
defer close(done)
Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero())
Expect(str).To(Equal(unknownStr))
return false, nil
}
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
<-testDone
return nil, errors.New("test done")
})
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})
Context("hijacking unidirectional streams", func() {
@ -321,7 +349,8 @@ var _ = Describe("Client", func() {
It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return true
}
@ -343,9 +372,34 @@ var _ = Describe("Client", func() {
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
It("handles errors that occur when reading the stream type", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expect(err).To(MatchError(testErr))
return true
}
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
}

View file

@ -14,7 +14,7 @@ import (
// FrameType is the frame type of a HTTP/3 frame
type FrameType uint64
type unknownFrameHandlerFunc func(FrameType) (processed bool, err error)
type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error)
type frame interface{}
@ -25,11 +25,20 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
for {
t, err := quicvarint.Read(qr)
if err != nil {
if unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(0, err)
if err != nil {
return nil, err
}
if hijacked {
return nil, errHijacked
}
}
return nil, err
}
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
if t > 0xd && unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(FrameType(t))
hijacked, err := unknownFrameHandler(FrameType(t), nil)
if err != nil {
return nil, err
}

View file

@ -2,6 +2,7 @@ package http3
import (
"bytes"
"errors"
"fmt"
"io"
@ -11,6 +12,10 @@ import (
. "github.com/onsi/gomega"
)
type errReader struct{ err error }
func (e errReader) Read([]byte) (int, error) { return 0, e.err }
var _ = Describe("Frames", func() {
appendVarInt := func(b []byte, val uint64) []byte {
buf := &bytes.Buffer{}
@ -189,7 +194,8 @@ var _ = Describe("Frames", func() {
buf.Write(customFrameContents)
var called bool
_, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
_, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
Expect(ft).To(BeEquivalentTo(1337))
called = true
b := make([]byte, 3)
@ -202,6 +208,19 @@ var _ = Describe("Frames", func() {
Expect(called).To(BeTrue())
})
It("passes on errors that occur when reading the frame type", func() {
testErr := errors.New("test error")
var called bool
_, err := parseNextFrame(errReader{err: testErr}, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero())
called = true
return true, nil
})
Expect(err).To(MatchError(errHijacked))
Expect(called).To(BeTrue())
})
It("reads a frame without hijacking the stream", func() {
buf := &bytes.Buffer{}
quicvarint.Write(buf, 1337)
@ -212,7 +231,8 @@ var _ = Describe("Frames", func() {
buf.WriteString("foobar")
var called bool
frame, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
frame, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
Expect(ft).To(BeEquivalentTo(1337))
called = true
return false, nil

View file

@ -53,13 +53,17 @@ type RoundTripper struct {
// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// If parsing the frame type fails, the error is passed to the callback.
// In that case, the frame type will not be set.
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
// When set, this callback is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
// Dial specifies an optional dial function for creating QUIC
// connections for requests.

View file

@ -178,13 +178,17 @@ type Server struct {
// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// If parsing the frame type fails, the error is passed to the callback.
// In that case, the frame type will not be set.
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo
@ -457,6 +461,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) {
return
}
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return
}
@ -471,7 +478,7 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
return
default:
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) {
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
@ -510,9 +517,7 @@ func (s *Server) maxHeaderBytes() uint64 {
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
var ufh unknownFrameHandlerFunc
if s.StreamHijacker != nil {
ufh = func(ft FrameType) (processed bool, err error) {
return s.StreamHijacker(ft, conn, str)
}
ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) }
}
frame, err := parseNextFrame(str, ufh)
if err != nil {

View file

@ -279,7 +279,8 @@ var _ = Describe("Server", func() {
It("hijacks a bidirectional stream of unknown frame type", func() {
frameTypeChan := make(chan FrameType, 1)
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return true, nil
}
@ -301,7 +302,8 @@ var _ = Describe("Server", func() {
It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1)
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, nil
}
@ -324,7 +326,8 @@ var _ = Describe("Server", func() {
It("cancels writing when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1)
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, errors.New("error in hijacker")
}
@ -344,6 +347,30 @@ var _ = Describe("Server", func() {
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
It("handles errors that occur when reading the stream type", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) {
defer close(done)
Expect(ft).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expect(err).To(MatchError(testErr))
return true, nil
}
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})
Context("hijacking unidirectional streams", func() {
@ -365,7 +392,8 @@ var _ = Describe("Server", func() {
It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return true
}
@ -386,9 +414,33 @@ var _ = Describe("Server", func() {
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
It("handles errors that occur when reading the stream type", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expect(err).To(MatchError(testErr))
return true
}
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
}