implement HTTP/3 stream hijacking

This commit is contained in:
Marten Seemann 2022-03-22 17:40:10 +01:00
parent a54816867f
commit 48a2cce9df
10 changed files with 152 additions and 37 deletions

View file

@ -81,7 +81,7 @@ func (r *body) readImpl(b []byte) (int, error) {
if r.bytesRemainingInFrame == 0 { if r.bytesRemainingInFrame == 0 {
parseLoop: parseLoop:
for { for {
frame, err := parseNextFrame(r.str) frame, err := parseNextFrame(r.str, nil)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -43,6 +43,7 @@ type roundTripperOpts struct {
EnableDatagram bool EnableDatagram bool
MaxHeaderBytes int64 MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64 AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Stream) (hijacked bool, err error)
} }
// client is a HTTP3 client doing requests // client is a HTTP3 client doing requests
@ -118,6 +119,9 @@ func (c *client) dial(ctx context.Context) error {
} }
}() }()
if c.opts.StreamHijacker != nil {
go c.handleBidirectionalStreams()
}
go c.handleUnidirectionalStreams() go c.handleUnidirectionalStreams()
return nil return nil
} }
@ -136,6 +140,30 @@ func (c *client) setupConn() error {
return err return err
} }
func (c *client) handleBidirectionalStreams() {
for {
str, err := c.conn.AcceptStream(context.Background())
if err != nil {
c.logger.Debugf("accepting bidirectional stream failed: %s", err)
return
}
go func(str quic.Stream) {
for {
_, err := parseNextFrame(str, func(ft FrameType) (processed bool, err error) {
return c.opts.StreamHijacker(ft, str)
})
if err == errHijacked {
return
}
if err != nil {
c.logger.Debugf("error handling stream: %s", err)
}
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}
}(str)
}
}
func (c *client) handleUnidirectionalStreams() { func (c *client) handleUnidirectionalStreams() {
for { for {
str, err := c.conn.AcceptUniStream(context.Background()) str, err := c.conn.AcceptUniStream(context.Background())
@ -165,7 +193,7 @@ func (c *client) handleUnidirectionalStreams() {
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return return
} }
f, err := parseNextFrame(str) f, err := parseNextFrame(str, nil)
if err != nil { if err != nil {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
return return
@ -276,7 +304,7 @@ func (c *client) doRequest(
return nil, newStreamError(errorInternalError, err) return nil, newStreamError(errorInternalError, err)
} }
frame, err := parseNextFrame(str) frame, err := parseNextFrame(str, nil)
if err != nil { if err != nil {
return nil, newStreamError(errorFrameError, err) return nil, newStreamError(errorFrameError, err)
} }

View file

@ -410,7 +410,7 @@ var _ = Describe("Client", func() {
fields := make(map[string]string) fields := make(map[string]string)
decoder := qpack.NewDecoder(nil) decoder := qpack.NewDecoder(nil)
frame, err := parseNextFrame(str) frame, err := parseNextFrame(str, nil)
ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame) headersFrame := frame.(*headersFrame)

View file

@ -2,6 +2,7 @@ package http3
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -10,15 +11,34 @@ import (
"github.com/lucas-clemente/quic-go/quicvarint" "github.com/lucas-clemente/quic-go/quicvarint"
) )
// FrameType is the frame type of a HTTP/3 frame
type FrameType uint64
type unknownFrameHandlerFunc func(FrameType) (processed bool, err error)
type frame interface{} type frame interface{}
func parseNextFrame(r io.Reader) (frame, error) { var errHijacked = errors.New("hijacked")
func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) {
qr := quicvarint.NewReader(r) qr := quicvarint.NewReader(r)
for { for {
t, err := quicvarint.Read(qr) t, err := quicvarint.Read(qr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
if t > 0xd && unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(FrameType(t))
if err != nil {
return nil, err
}
// If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it.
if hijacked {
return nil, errHijacked
}
continue
}
l, err := quicvarint.Read(qr) l, err := quicvarint.Read(qr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -32,18 +52,13 @@ func parseNextFrame(r io.Reader) (frame, error) {
case 0x4: case 0x4:
return parseSettingsFrame(r, l) return parseSettingsFrame(r, l)
case 0x3: // CANCEL_PUSH case 0x3: // CANCEL_PUSH
fallthrough
case 0x5: // PUSH_PROMISE case 0x5: // PUSH_PROMISE
fallthrough
case 0x7: // GOAWAY case 0x7: // GOAWAY
fallthrough
case 0xd: // MAX_PUSH_ID case 0xd: // MAX_PUSH_ID
fallthrough }
default: // skip over unknown frames
// skip over unknown frames if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil {
if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil { return nil, err
return nil, err
}
} }
} }
} }

View file

@ -24,7 +24,7 @@ var _ = Describe("Frames", func() {
data = append(data, make([]byte, 0x42)...) data = append(data, make([]byte, 0x42)...)
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
(&dataFrame{Length: 0x1234}).Write(buf) (&dataFrame{Length: 0x1234}).Write(buf)
frame, err := parseNextFrame(buf) frame, err := parseNextFrame(buf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234))) Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234)))
@ -34,7 +34,7 @@ var _ = Describe("Frames", func() {
It("parses", func() { It("parses", func() {
data := appendVarInt(nil, 0) // type byte data := appendVarInt(nil, 0) // type byte
data = appendVarInt(data, 0x1337) data = appendVarInt(data, 0x1337)
frame, err := parseNextFrame(bytes.NewReader(data)) frame, err := parseNextFrame(bytes.NewReader(data), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337))) Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337)))
@ -43,7 +43,7 @@ var _ = Describe("Frames", func() {
It("writes", func() { It("writes", func() {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
(&dataFrame{Length: 0xdeadbeef}).Write(buf) (&dataFrame{Length: 0xdeadbeef}).Write(buf)
frame, err := parseNextFrame(buf) frame, err := parseNextFrame(buf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
@ -55,7 +55,7 @@ var _ = Describe("Frames", func() {
It("parses", func() { It("parses", func() {
data := appendVarInt(nil, 1) // type byte data := appendVarInt(nil, 1) // type byte
data = appendVarInt(data, 0x1337) data = appendVarInt(data, 0x1337)
frame, err := parseNextFrame(bytes.NewReader(data)) frame, err := parseNextFrame(bytes.NewReader(data), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337))) Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337)))
@ -64,7 +64,7 @@ var _ = Describe("Frames", func() {
It("writes", func() { It("writes", func() {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
(&headersFrame{Length: 0xdeadbeef}).Write(buf) (&headersFrame{Length: 0xdeadbeef}).Write(buf)
frame, err := parseNextFrame(buf) frame, err := parseNextFrame(buf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
@ -81,7 +81,7 @@ var _ = Describe("Frames", func() {
data := appendVarInt(nil, 4) // type byte data := appendVarInt(nil, 4) // type byte
data = appendVarInt(data, uint64(len(settings))) data = appendVarInt(data, uint64(len(settings)))
data = append(data, settings...) data = append(data, settings...)
frame, err := parseNextFrame(bytes.NewReader(data)) frame, err := parseNextFrame(bytes.NewReader(data), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{})) Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{}))
sf := frame.(*settingsFrame) sf := frame.(*settingsFrame)
@ -97,7 +97,7 @@ var _ = Describe("Frames", func() {
data := appendVarInt(nil, 4) // type byte data := appendVarInt(nil, 4) // type byte
data = appendVarInt(data, uint64(len(settings))) data = appendVarInt(data, uint64(len(settings)))
data = append(data, settings...) data = append(data, settings...)
_, err := parseNextFrame(bytes.NewReader(data)) _, err := parseNextFrame(bytes.NewReader(data), nil)
Expect(err).To(MatchError("duplicate setting: 13")) Expect(err).To(MatchError("duplicate setting: 13"))
}) })
@ -109,7 +109,7 @@ var _ = Describe("Frames", func() {
}} }}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
sf.Write(buf) sf.Write(buf)
frame, err := parseNextFrame(buf) frame, err := parseNextFrame(buf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(sf)) Expect(frame).To(Equal(sf))
}) })
@ -123,13 +123,13 @@ var _ = Describe("Frames", func() {
sf.Write(buf) sf.Write(buf)
data := buf.Bytes() data := buf.Bytes()
_, err := parseNextFrame(bytes.NewReader(data)) _, err := parseNextFrame(bytes.NewReader(data), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := range data { for i := range data {
b := make([]byte, i) b := make([]byte, i)
copy(b, data[:i]) copy(b, data[:i])
_, err := parseNextFrame(bytes.NewReader(b)) _, err := parseNextFrame(bytes.NewReader(b), nil)
Expect(err).To(MatchError(io.EOF)) Expect(err).To(MatchError(io.EOF))
} }
}) })
@ -141,7 +141,7 @@ var _ = Describe("Frames", func() {
data := appendVarInt(nil, 4) // type byte data := appendVarInt(nil, 4) // type byte
data = appendVarInt(data, uint64(len(settings))) data = appendVarInt(data, uint64(len(settings)))
data = append(data, settings...) data = append(data, settings...)
f, err := parseNextFrame(bytes.NewReader(data)) f, err := parseNextFrame(bytes.NewReader(data), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(f).To(BeAssignableToTypeOf(&settingsFrame{})) Expect(f).To(BeAssignableToTypeOf(&settingsFrame{}))
sf := f.(*settingsFrame) sf := f.(*settingsFrame)
@ -156,7 +156,7 @@ var _ = Describe("Frames", func() {
data := appendVarInt(nil, 4) // type byte data := appendVarInt(nil, 4) // type byte
data = appendVarInt(data, uint64(len(settings))) data = appendVarInt(data, uint64(len(settings)))
data = append(data, settings...) data = append(data, settings...)
_, err := parseNextFrame(bytes.NewReader(data)) _, err := parseNextFrame(bytes.NewReader(data), nil)
Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingDatagram))) Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingDatagram)))
}) })
@ -166,7 +166,7 @@ var _ = Describe("Frames", func() {
data := appendVarInt(nil, 4) // type byte data := appendVarInt(nil, 4) // type byte
data = appendVarInt(data, uint64(len(settings))) data = appendVarInt(data, uint64(len(settings)))
data = append(data, settings...) data = append(data, settings...)
_, err := parseNextFrame(bytes.NewReader(data)) _, err := parseNextFrame(bytes.NewReader(data), nil)
Expect(err).To(MatchError("invalid value for H3_DATAGRAM: 1337")) Expect(err).To(MatchError("invalid value for H3_DATAGRAM: 1337"))
}) })
@ -174,10 +174,55 @@ var _ = Describe("Frames", func() {
sf := &settingsFrame{Datagram: true} sf := &settingsFrame{Datagram: true}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
sf.Write(buf) sf.Write(buf)
frame, err := parseNextFrame(buf) frame, err := parseNextFrame(buf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(sf)) Expect(frame).To(Equal(sf))
}) })
}) })
}) })
Context("hijacking", func() {
It("reads a frame without hijacking the stream", func() {
buf := &bytes.Buffer{}
quicvarint.Write(buf, 1337)
customFrameContents := []byte("foobar")
buf.Write(customFrameContents)
var called bool
_, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
Expect(ft).To(BeEquivalentTo(1337))
called = true
b := make([]byte, 3)
_, err = io.ReadFull(buf, b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foo"))
return true, nil
})
Expect(err).To(MatchError(errHijacked))
Expect(called).To(BeTrue())
})
It("reads a frame without hijacking the stream", func() {
buf := &bytes.Buffer{}
quicvarint.Write(buf, 1337)
customFrameContents := []byte("custom frame")
buf.Write(customFrameContents)
(&dataFrame{Length: 6}).Write(buf)
buf.WriteString("foobar")
var called bool
frame, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
Expect(ft).To(BeEquivalentTo(1337))
called = true
b := make([]byte, len(customFrameContents))
_, err = io.ReadFull(buf, b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal(string(customFrameContents)))
return false, nil
})
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(&dataFrame{Length: 6}))
Expect(called).To(BeTrue())
})
})
}) })

View file

@ -30,7 +30,7 @@ var _ = Describe("Request Writer", func() {
) )
decode := func(str io.Reader) map[string]string { decode := func(str io.Reader) map[string]string {
frame, err := parseNextFrame(str) frame, err := parseNextFrame(str, nil)
ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame) headersFrame := frame.(*headersFrame)
@ -85,7 +85,7 @@ var _ = Describe("Request Writer", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(contentLength).To(BeNumerically(">", 0)) Expect(contentLength).To(BeNumerically(">", 0))
frame, err := parseNextFrame(strBuf) frame, err := parseNextFrame(strBuf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6))
@ -102,7 +102,7 @@ var _ = Describe("Request Writer", func() {
headerFields := decode(strBuf) headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
frame, err := parseNextFrame(strBuf) frame, err := parseNextFrame(strBuf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6))

View file

@ -33,7 +33,7 @@ var _ = Describe("Response Writer", func() {
fields := make(map[string][]string) fields := make(map[string][]string)
decoder := qpack.NewDecoder(nil) decoder := qpack.NewDecoder(nil)
frame, err := parseNextFrame(str) frame, err := parseNextFrame(str, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame) headersFrame := frame.(*headersFrame)
@ -49,7 +49,7 @@ var _ = Describe("Response Writer", func() {
} }
getData := func(str io.Reader) []byte { getData := func(str io.Reader) []byte {
frame, err := parseNextFrame(str) frame, err := parseNextFrame(str, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
df := frame.(*dataFrame) df := frame.(*dataFrame)

View file

@ -51,6 +51,13 @@ type RoundTripper struct {
// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
AdditionalSettings map[uint64]uint64 AdditionalSettings map[uint64]uint64
// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// Callers can either process the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Stream) (hijacked bool, err error)
// Dial specifies an optional dial function for creating QUIC // Dial specifies an optional dial function for creating QUIC
// connections for requests. // connections for requests.
// If Dial is nil, quic.DialAddrEarlyContext will be used. // If Dial is nil, quic.DialAddrEarlyContext will be used.
@ -146,6 +153,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
EnableDatagram: r.EnableDatagrams, EnableDatagram: r.EnableDatagrams,
DisableCompression: r.DisableCompression, DisableCompression: r.DisableCompression,
MaxHeaderBytes: r.MaxResponseHeaderBytes, MaxHeaderBytes: r.MaxResponseHeaderBytes,
StreamHijacker: r.StreamHijacker,
}, },
r.QuicConfig, r.QuicConfig,
r.Dial, r.Dial,

View file

@ -144,6 +144,13 @@ type Server struct {
// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
AdditionalSettings map[uint64]uint64 AdditionalSettings map[uint64]uint64
// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// Callers can either process the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
mutex sync.RWMutex mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo listeners map[*quic.EarlyListener]listenerInfo
@ -186,7 +193,7 @@ func (s *Server) Serve(conn net.PacketConn) error {
return s.serveConn(s.TLSConfig, conn) return s.serveConn(s.TLSConfig, conn)
} }
// Serve an existing QUIC listener. // ServeListener serves an existing QUIC listener.
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config // Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener. // and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener. // Closing the server does close the listener.
@ -367,6 +374,9 @@ func (s *Server) handleConn(conn quic.EarlyConnection) {
rerr := s.handleRequest(conn, str, decoder, func() { rerr := s.handleRequest(conn, str, decoder, func() {
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "")
}) })
if rerr.err == errHijacked {
return
}
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
s.logger.Debugf("Handling request failed: %s", err) s.logger.Debugf("Handling request failed: %s", err)
if rerr.streamErr != 0 { if rerr.streamErr != 0 {
@ -414,7 +424,7 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return return
} }
f, err := parseNextFrame(str) f, err := parseNextFrame(str, nil)
if err != nil { if err != nil {
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
return return
@ -445,8 +455,17 @@ func (s *Server) maxHeaderBytes() uint64 {
} }
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
frame, err := parseNextFrame(str) var ufh unknownFrameHandlerFunc
if s.StreamHijacker != nil {
ufh = func(ft FrameType) (processed bool, err error) {
return s.StreamHijacker(ft, conn, str)
}
}
frame, err := parseNextFrame(str, ufh)
if err != nil { if err != nil {
if err == errHijacked {
return requestError{err: errHijacked}
}
return newStreamError(errorRequestIncomplete, err) return newStreamError(errorRequestIncomplete, err)
} }
hf, ok := frame.(*headersFrame) hf, ok := frame.(*headersFrame)

View file

@ -117,7 +117,7 @@ var _ = Describe("Server", func() {
fields := make(map[string][]string) fields := make(map[string][]string)
decoder := qpack.NewDecoder(nil) decoder := qpack.NewDecoder(nil)
frame, err := parseNextFrame(str) frame, err := parseNextFrame(str, nil)
ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame) headersFrame := frame.(*headersFrame)