mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
pass frame / stream type parsing errors to the hijacker callbacks
When a stream is reset, we might not have received the frame / stream type yet. The callback might be able to identify if it was a stream intended for that application by analyzing the stream reset error.
This commit is contained in:
parent
5cb2e8265c
commit
96c0daceca
7 changed files with 181 additions and 36 deletions
|
@ -43,8 +43,8 @@ type roundTripperOpts struct {
|
||||||
EnableDatagram bool
|
EnableDatagram bool
|
||||||
MaxHeaderBytes int64
|
MaxHeaderBytes int64
|
||||||
AdditionalSettings map[uint64]uint64
|
AdditionalSettings map[uint64]uint64
|
||||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
|
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// client is a HTTP3 client doing requests
|
// client is a HTTP3 client doing requests
|
||||||
|
@ -151,18 +151,16 @@ func (c *client) handleBidirectionalStreams() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func(str quic.Stream) {
|
go func(str quic.Stream) {
|
||||||
for {
|
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
|
||||||
_, err := parseNextFrame(str, func(ft FrameType) (processed bool, err error) {
|
return c.opts.StreamHijacker(ft, c.conn, str, e)
|
||||||
return c.opts.StreamHijacker(ft, c.conn, str)
|
})
|
||||||
})
|
if err == errHijacked {
|
||||||
if err == errHijacked {
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
c.logger.Debugf("error handling stream: %s", err)
|
|
||||||
}
|
|
||||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Debugf("error handling stream: %s", err)
|
||||||
|
}
|
||||||
|
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
||||||
}(str)
|
}(str)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -178,6 +176,9 @@ func (c *client) handleUnidirectionalStreams() {
|
||||||
go func(str quic.ReceiveStream) {
|
go func(str quic.ReceiveStream) {
|
||||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -193,7 +194,7 @@ func (c *client) handleUnidirectionalStreams() {
|
||||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
|
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) {
|
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||||
|
|
|
@ -221,7 +221,8 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("hijacks a bidirectional stream of unknown frame type", func() {
|
It("hijacks a bidirectional stream of unknown frame type", func() {
|
||||||
frameTypeChan := make(chan FrameType, 1)
|
frameTypeChan := make(chan FrameType, 1)
|
||||||
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
|
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).ToNot(HaveOccurred())
|
||||||
frameTypeChan <- ft
|
frameTypeChan <- ft
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -243,7 +244,8 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
|
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
|
||||||
frameTypeChan := make(chan FrameType, 1)
|
frameTypeChan := make(chan FrameType, 1)
|
||||||
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
|
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).ToNot(HaveOccurred())
|
||||||
frameTypeChan <- ft
|
frameTypeChan <- ft
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -265,7 +267,8 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("closes the connection when hijacker returned error", func() {
|
It("closes the connection when hijacker returned error", func() {
|
||||||
frameTypeChan := make(chan FrameType, 1)
|
frameTypeChan := make(chan FrameType, 1)
|
||||||
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
|
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).ToNot(HaveOccurred())
|
||||||
frameTypeChan <- ft
|
frameTypeChan <- ft
|
||||||
return false, errors.New("error in hijacker")
|
return false, errors.New("error in hijacker")
|
||||||
}
|
}
|
||||||
|
@ -284,6 +287,31 @@ var _ = Describe("Client", func() {
|
||||||
Expect(err).To(MatchError("done"))
|
Expect(err).To(MatchError("done"))
|
||||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("handles errors that occur when reading the frame type", func() {
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
done := make(chan struct{})
|
||||||
|
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
|
||||||
|
defer close(done)
|
||||||
|
Expect(e).To(MatchError(testErr))
|
||||||
|
Expect(ft).To(BeZero())
|
||||||
|
Expect(str).To(Equal(unknownStr))
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
|
||||||
|
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
|
||||||
|
conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
|
||||||
|
<-testDone
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
})
|
||||||
|
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
|
||||||
|
_, err := client.RoundTripOpt(request, RoundTripOpt{})
|
||||||
|
Expect(err).To(MatchError("done"))
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("hijacking unidirectional streams", func() {
|
Context("hijacking unidirectional streams", func() {
|
||||||
|
@ -321,7 +349,8 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("hijacks an unidirectional stream of unknown stream type", func() {
|
It("hijacks an unidirectional stream of unknown stream type", func() {
|
||||||
streamTypeChan := make(chan StreamType, 1)
|
streamTypeChan := make(chan StreamType, 1)
|
||||||
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
|
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
streamTypeChan <- st
|
streamTypeChan <- st
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -343,9 +372,34 @@ var _ = Describe("Client", func() {
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("handles errors that occur when reading the stream type", func() {
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
done := make(chan struct{})
|
||||||
|
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
|
||||||
|
defer close(done)
|
||||||
|
Expect(st).To(BeZero())
|
||||||
|
Expect(str).To(Equal(unknownStr))
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
<-testDone
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
})
|
||||||
|
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||||
|
Expect(err).To(MatchError("done"))
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
})
|
||||||
|
|
||||||
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
|
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
|
||||||
streamTypeChan := make(chan StreamType, 1)
|
streamTypeChan := make(chan StreamType, 1)
|
||||||
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
|
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
streamTypeChan <- st
|
streamTypeChan <- st
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
// FrameType is the frame type of a HTTP/3 frame
|
// FrameType is the frame type of a HTTP/3 frame
|
||||||
type FrameType uint64
|
type FrameType uint64
|
||||||
|
|
||||||
type unknownFrameHandlerFunc func(FrameType) (processed bool, err error)
|
type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error)
|
||||||
|
|
||||||
type frame interface{}
|
type frame interface{}
|
||||||
|
|
||||||
|
@ -25,11 +25,20 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
|
||||||
for {
|
for {
|
||||||
t, err := quicvarint.Read(qr)
|
t, err := quicvarint.Read(qr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if unknownFrameHandler != nil {
|
||||||
|
hijacked, err := unknownFrameHandler(0, err)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if hijacked {
|
||||||
|
return nil, errHijacked
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
|
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
|
||||||
if t > 0xd && unknownFrameHandler != nil {
|
if t > 0xd && unknownFrameHandler != nil {
|
||||||
hijacked, err := unknownFrameHandler(FrameType(t))
|
hijacked, err := unknownFrameHandler(FrameType(t), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
@ -11,6 +12,10 @@ import (
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type errReader struct{ err error }
|
||||||
|
|
||||||
|
func (e errReader) Read([]byte) (int, error) { return 0, e.err }
|
||||||
|
|
||||||
var _ = Describe("Frames", func() {
|
var _ = Describe("Frames", func() {
|
||||||
appendVarInt := func(b []byte, val uint64) []byte {
|
appendVarInt := func(b []byte, val uint64) []byte {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
@ -189,7 +194,8 @@ var _ = Describe("Frames", func() {
|
||||||
buf.Write(customFrameContents)
|
buf.Write(customFrameContents)
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
_, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
|
_, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).ToNot(HaveOccurred())
|
||||||
Expect(ft).To(BeEquivalentTo(1337))
|
Expect(ft).To(BeEquivalentTo(1337))
|
||||||
called = true
|
called = true
|
||||||
b := make([]byte, 3)
|
b := make([]byte, 3)
|
||||||
|
@ -202,6 +208,19 @@ var _ = Describe("Frames", func() {
|
||||||
Expect(called).To(BeTrue())
|
Expect(called).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("passes on errors that occur when reading the frame type", func() {
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
var called bool
|
||||||
|
_, err := parseNextFrame(errReader{err: testErr}, func(ft FrameType, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).To(MatchError(testErr))
|
||||||
|
Expect(ft).To(BeZero())
|
||||||
|
called = true
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
Expect(err).To(MatchError(errHijacked))
|
||||||
|
Expect(called).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
It("reads a frame without hijacking the stream", func() {
|
It("reads a frame without hijacking the stream", func() {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
quicvarint.Write(buf, 1337)
|
quicvarint.Write(buf, 1337)
|
||||||
|
@ -212,7 +231,8 @@ var _ = Describe("Frames", func() {
|
||||||
buf.WriteString("foobar")
|
buf.WriteString("foobar")
|
||||||
|
|
||||||
var called bool
|
var called bool
|
||||||
frame, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
|
frame, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).ToNot(HaveOccurred())
|
||||||
Expect(ft).To(BeEquivalentTo(1337))
|
Expect(ft).To(BeEquivalentTo(1337))
|
||||||
called = true
|
called = true
|
||||||
return false, nil
|
return false, nil
|
||||||
|
|
|
@ -53,13 +53,17 @@ type RoundTripper struct {
|
||||||
|
|
||||||
// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
|
// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
|
||||||
// It is called right after parsing the frame type.
|
// It is called right after parsing the frame type.
|
||||||
|
// If parsing the frame type fails, the error is passed to the callback.
|
||||||
|
// In that case, the frame type will not be set.
|
||||||
// Callers can either ignore the frame and return control of the stream back to HTTP/3
|
// Callers can either ignore the frame and return control of the stream back to HTTP/3
|
||||||
// (by returning hijacked false).
|
// (by returning hijacked false).
|
||||||
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
||||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||||
|
|
||||||
// When set, this callback is called for unknown unidirectional stream of unknown stream type.
|
// When set, this callback is called for unknown unidirectional stream of unknown stream type.
|
||||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
|
// If parsing the stream type fails, the error is passed to the callback.
|
||||||
|
// In that case, the stream type will not be set.
|
||||||
|
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||||
|
|
||||||
// Dial specifies an optional dial function for creating QUIC
|
// Dial specifies an optional dial function for creating QUIC
|
||||||
// connections for requests.
|
// connections for requests.
|
||||||
|
|
|
@ -178,13 +178,17 @@ type Server struct {
|
||||||
|
|
||||||
// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
|
// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
|
||||||
// It is called right after parsing the frame type.
|
// It is called right after parsing the frame type.
|
||||||
|
// If parsing the frame type fails, the error is passed to the callback.
|
||||||
|
// In that case, the frame type will not be set.
|
||||||
// Callers can either ignore the frame and return control of the stream back to HTTP/3
|
// Callers can either ignore the frame and return control of the stream back to HTTP/3
|
||||||
// (by returning hijacked false).
|
// (by returning hijacked false).
|
||||||
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
||||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||||
|
|
||||||
// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
|
// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
|
||||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
|
// If parsing the stream type fails, the error is passed to the callback.
|
||||||
|
// In that case, the stream type will not be set.
|
||||||
|
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
listeners map[*quic.EarlyListener]listenerInfo
|
listeners map[*quic.EarlyListener]listenerInfo
|
||||||
|
@ -457,6 +461,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
|
||||||
go func(str quic.ReceiveStream) {
|
go func(str quic.ReceiveStream) {
|
||||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -471,7 +478,7 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
|
||||||
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
|
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) {
|
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||||
|
@ -510,9 +517,7 @@ func (s *Server) maxHeaderBytes() uint64 {
|
||||||
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
|
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
|
||||||
var ufh unknownFrameHandlerFunc
|
var ufh unknownFrameHandlerFunc
|
||||||
if s.StreamHijacker != nil {
|
if s.StreamHijacker != nil {
|
||||||
ufh = func(ft FrameType) (processed bool, err error) {
|
ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) }
|
||||||
return s.StreamHijacker(ft, conn, str)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
frame, err := parseNextFrame(str, ufh)
|
frame, err := parseNextFrame(str, ufh)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -279,7 +279,8 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
It("hijacks a bidirectional stream of unknown frame type", func() {
|
It("hijacks a bidirectional stream of unknown frame type", func() {
|
||||||
frameTypeChan := make(chan FrameType, 1)
|
frameTypeChan := make(chan FrameType, 1)
|
||||||
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
|
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).ToNot(HaveOccurred())
|
||||||
frameTypeChan <- ft
|
frameTypeChan <- ft
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -301,7 +302,8 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
|
It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
|
||||||
frameTypeChan := make(chan FrameType, 1)
|
frameTypeChan := make(chan FrameType, 1)
|
||||||
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
|
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).ToNot(HaveOccurred())
|
||||||
frameTypeChan <- ft
|
frameTypeChan <- ft
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -324,7 +326,8 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
It("cancels writing when hijacker returned error", func() {
|
It("cancels writing when hijacker returned error", func() {
|
||||||
frameTypeChan := make(chan FrameType, 1)
|
frameTypeChan := make(chan FrameType, 1)
|
||||||
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
|
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||||
|
Expect(e).ToNot(HaveOccurred())
|
||||||
frameTypeChan <- ft
|
frameTypeChan <- ft
|
||||||
return false, errors.New("error in hijacker")
|
return false, errors.New("error in hijacker")
|
||||||
}
|
}
|
||||||
|
@ -344,6 +347,30 @@ var _ = Describe("Server", func() {
|
||||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("handles errors that occur when reading the stream type", func() {
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
done := make(chan struct{})
|
||||||
|
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) {
|
||||||
|
defer close(done)
|
||||||
|
Expect(ft).To(BeZero())
|
||||||
|
Expect(str).To(Equal(unknownStr))
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
|
||||||
|
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
|
||||||
|
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
<-testDone
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
})
|
||||||
|
s.handleConn(conn)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("hijacking unidirectional streams", func() {
|
Context("hijacking unidirectional streams", func() {
|
||||||
|
@ -365,7 +392,8 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
It("hijacks an unidirectional stream of unknown stream type", func() {
|
It("hijacks an unidirectional stream of unknown stream type", func() {
|
||||||
streamTypeChan := make(chan StreamType, 1)
|
streamTypeChan := make(chan StreamType, 1)
|
||||||
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
|
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
streamTypeChan <- st
|
streamTypeChan <- st
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -386,9 +414,33 @@ var _ = Describe("Server", func() {
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("handles errors that occur when reading the stream type", func() {
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
done := make(chan struct{})
|
||||||
|
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
|
||||||
|
defer close(done)
|
||||||
|
Expect(st).To(BeZero())
|
||||||
|
Expect(str).To(Equal(unknownStr))
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
<-testDone
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
})
|
||||||
|
s.handleConn(conn)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
})
|
||||||
|
|
||||||
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
|
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
|
||||||
streamTypeChan := make(chan StreamType, 1)
|
streamTypeChan := make(chan StreamType, 1)
|
||||||
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
|
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
streamTypeChan <- st
|
streamTypeChan <- st
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue