mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
implement HTTP/3 unidirectional stream hijacking (#3389)
* implement HTTP/3 unistream hijacking * Apply suggestions from code review Fixed name consistency. Co-authored-by: Marten Seemann <martenseemann@gmail.com> * rename unistream to unidirectional stream Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
parent
6d4a694183
commit
1a0d577854
5 changed files with 168 additions and 2 deletions
|
@ -44,6 +44,7 @@ type roundTripperOpts struct {
|
||||||
MaxHeaderBytes int64
|
MaxHeaderBytes int64
|
||||||
AdditionalSettings map[uint64]uint64
|
AdditionalSettings map[uint64]uint64
|
||||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
||||||
|
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// client is a HTTP3 client doing requests
|
// client is a HTTP3 client doing requests
|
||||||
|
@ -174,7 +175,7 @@ func (c *client) handleUnidirectionalStreams() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func(str quic.ReceiveStream) {
|
||||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||||
|
@ -192,6 +193,9 @@ func (c *client) handleUnidirectionalStreams() {
|
||||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
|
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) {
|
||||||
|
return
|
||||||
|
}
|
||||||
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -214,7 +218,7 @@ func (c *client) handleUnidirectionalStreams() {
|
||||||
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
|
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
|
||||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
|
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
|
||||||
}
|
}
|
||||||
}()
|
}(str)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -185,6 +185,89 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("hijacking unidirectional streams", func() {
|
||||||
|
var (
|
||||||
|
request *http.Request
|
||||||
|
conn *mockquic.MockEarlyConnection
|
||||||
|
settingsFrameWritten chan struct{}
|
||||||
|
)
|
||||||
|
testDone := make(chan struct{})
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
testDone = make(chan struct{})
|
||||||
|
settingsFrameWritten = make(chan struct{})
|
||||||
|
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
close(settingsFrameWritten)
|
||||||
|
})
|
||||||
|
conn = mockquic.NewMockEarlyConnection(mockCtrl)
|
||||||
|
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
|
||||||
|
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
|
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
|
||||||
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
testDone <- struct{}{}
|
||||||
|
Eventually(settingsFrameWritten).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("hijacks an unidirectional stream of unknown stream type", func() {
|
||||||
|
streamTypeChan := make(chan StreamType, 1)
|
||||||
|
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
|
||||||
|
streamTypeChan <- st
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
quicvarint.Write(buf, 0x54)
|
||||||
|
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
return unknownStr, nil
|
||||||
|
})
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
<-testDone
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
})
|
||||||
|
_, err := client.RoundTrip(request)
|
||||||
|
Expect(err).To(MatchError("done"))
|
||||||
|
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||||
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
})
|
||||||
|
|
||||||
|
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
|
||||||
|
streamTypeChan := make(chan StreamType, 1)
|
||||||
|
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
|
||||||
|
streamTypeChan <- st
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
quicvarint.Write(buf, 0x54)
|
||||||
|
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
|
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
return unknownStr, nil
|
||||||
|
})
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
<-testDone
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
})
|
||||||
|
_, err := client.RoundTrip(request)
|
||||||
|
Expect(err).To(MatchError("done"))
|
||||||
|
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||||
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("control stream handling", func() {
|
Context("control stream handling", func() {
|
||||||
var (
|
var (
|
||||||
request *http.Request
|
request *http.Request
|
||||||
|
|
|
@ -58,6 +58,9 @@ type RoundTripper struct {
|
||||||
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
||||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
||||||
|
|
||||||
|
// When set, this callback is called for unknown unidirectional stream of unknown stream type.
|
||||||
|
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
|
||||||
|
|
||||||
// Dial specifies an optional dial function for creating QUIC
|
// Dial specifies an optional dial function for creating QUIC
|
||||||
// connections for requests.
|
// connections for requests.
|
||||||
// If Dial is nil, quic.DialAddrEarlyContext will be used.
|
// If Dial is nil, quic.DialAddrEarlyContext will be used.
|
||||||
|
@ -154,6 +157,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
||||||
DisableCompression: r.DisableCompression,
|
DisableCompression: r.DisableCompression,
|
||||||
MaxHeaderBytes: r.MaxResponseHeaderBytes,
|
MaxHeaderBytes: r.MaxResponseHeaderBytes,
|
||||||
StreamHijacker: r.StreamHijacker,
|
StreamHijacker: r.StreamHijacker,
|
||||||
|
UniStreamHijacker: r.UniStreamHijacker,
|
||||||
},
|
},
|
||||||
r.QuicConfig,
|
r.QuicConfig,
|
||||||
r.Dial,
|
r.Dial,
|
||||||
|
|
|
@ -33,6 +33,9 @@ const (
|
||||||
nextProtoH3 = "h3"
|
nextProtoH3 = "h3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StreamType is the stream type of a unidirectional stream.
|
||||||
|
type StreamType uint64
|
||||||
|
|
||||||
const (
|
const (
|
||||||
streamTypeControlStream = 0
|
streamTypeControlStream = 0
|
||||||
streamTypePushStream = 1
|
streamTypePushStream = 1
|
||||||
|
@ -151,6 +154,9 @@ type Server struct {
|
||||||
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
||||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
|
||||||
|
|
||||||
|
// When set, this callback is called for unknown unidirectional stream of unknown stream type.
|
||||||
|
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
listeners map[*quic.EarlyListener]listenerInfo
|
listeners map[*quic.EarlyListener]listenerInfo
|
||||||
|
|
||||||
|
@ -421,6 +427,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
|
||||||
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
|
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) {
|
||||||
|
return
|
||||||
|
}
|
||||||
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -238,6 +238,72 @@ var _ = Describe("Server", func() {
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("hijacking unidirectional streams", func() {
|
||||||
|
var conn *mockquic.MockEarlyConnection
|
||||||
|
testDone := make(chan struct{})
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
testDone = make(chan struct{})
|
||||||
|
conn = mockquic.NewMockEarlyConnection(mockCtrl)
|
||||||
|
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
controlStr.EXPECT().Write(gomock.Any())
|
||||||
|
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
|
||||||
|
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
|
||||||
|
conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
|
||||||
|
conn.EXPECT().LocalAddr().AnyTimes()
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() { testDone <- struct{}{} })
|
||||||
|
|
||||||
|
It("hijacks an unidirectional stream of unknown stream type", func() {
|
||||||
|
streamTypeChan := make(chan StreamType, 1)
|
||||||
|
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
|
||||||
|
streamTypeChan <- st
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
quicvarint.Write(buf, 0x54)
|
||||||
|
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
return unknownStr, nil
|
||||||
|
})
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
<-testDone
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
})
|
||||||
|
s.handleConn(conn)
|
||||||
|
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||||
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
})
|
||||||
|
|
||||||
|
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
|
||||||
|
streamTypeChan := make(chan StreamType, 1)
|
||||||
|
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
|
||||||
|
streamTypeChan <- st
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
quicvarint.Write(buf, 0x54)
|
||||||
|
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||||
|
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||||
|
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
return unknownStr, nil
|
||||||
|
})
|
||||||
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||||
|
<-testDone
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
})
|
||||||
|
s.handleConn(conn)
|
||||||
|
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||||
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("control stream handling", func() {
|
Context("control stream handling", func() {
|
||||||
var conn *mockquic.MockEarlyConnection
|
var conn *mockquic.MockEarlyConnection
|
||||||
testDone := make(chan struct{})
|
testDone := make(chan struct{})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue