diff --git a/http3/body.go b/http3/body.go index b6363dc6..b3cdf801 100644 --- a/http3/body.go +++ b/http3/body.go @@ -81,7 +81,7 @@ func (r *body) readImpl(b []byte) (int, error) { if r.bytesRemainingInFrame == 0 { parseLoop: for { - frame, err := parseNextFrame(r.str) + frame, err := parseNextFrame(r.str, nil) if err != nil { return 0, err } diff --git a/http3/client.go b/http3/client.go index bf12ecd1..f85d8d55 100644 --- a/http3/client.go +++ b/http3/client.go @@ -43,6 +43,7 @@ type roundTripperOpts struct { EnableDatagram bool MaxHeaderBytes int64 AdditionalSettings map[uint64]uint64 + StreamHijacker func(FrameType, quic.Stream) (hijacked bool, err error) } // client is a HTTP3 client doing requests @@ -118,6 +119,9 @@ func (c *client) dial(ctx context.Context) error { } }() + if c.opts.StreamHijacker != nil { + go c.handleBidirectionalStreams() + } go c.handleUnidirectionalStreams() return nil } @@ -136,6 +140,30 @@ func (c *client) setupConn() error { return err } +func (c *client) handleBidirectionalStreams() { + for { + str, err := c.conn.AcceptStream(context.Background()) + if err != nil { + c.logger.Debugf("accepting bidirectional stream failed: %s", err) + return + } + go func(str quic.Stream) { + for { + _, err := parseNextFrame(str, func(ft FrameType) (processed bool, err error) { + return c.opts.StreamHijacker(ft, str) + }) + if err == errHijacked { + return + } + if err != nil { + c.logger.Debugf("error handling stream: %s", err) + } + c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + } + }(str) + } +} + func (c *client) handleUnidirectionalStreams() { for { str, err := c.conn.AcceptUniStream(context.Background()) @@ -165,7 +193,7 @@ func (c *client) handleUnidirectionalStreams() { str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } - f, err := parseNextFrame(str) + f, err := parseNextFrame(str, nil) if err != nil { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") return @@ -276,7 +304,7 @@ func (c *client) doRequest( return nil, newStreamError(errorInternalError, err) } - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) if err != nil { return nil, newStreamError(errorFrameError, err) } diff --git a/http3/client_test.go b/http3/client_test.go index 28733ac0..b1d8b48b 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -410,7 +410,7 @@ var _ = Describe("Client", func() { fields := make(map[string]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) diff --git a/http3/frames.go b/http3/frames.go index c1225b8d..5fb4f082 100644 --- a/http3/frames.go +++ b/http3/frames.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -10,15 +11,34 @@ import ( "github.com/lucas-clemente/quic-go/quicvarint" ) +// FrameType is the frame type of a HTTP/3 frame +type FrameType uint64 + +type unknownFrameHandlerFunc func(FrameType) (processed bool, err error) + type frame interface{} -func parseNextFrame(r io.Reader) (frame, error) { +var errHijacked = errors.New("hijacked") + +func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) { qr := quicvarint.NewReader(r) for { t, err := quicvarint.Read(qr) if err != nil { return nil, err } + // Call the unknownFrameHandler for frames not defined in the HTTP/3 spec + if t > 0xd && unknownFrameHandler != nil { + hijacked, err := unknownFrameHandler(FrameType(t)) + if err != nil { + return nil, err + } + // If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it. + if hijacked { + return nil, errHijacked + } + continue + } l, err := quicvarint.Read(qr) if err != nil { return nil, err @@ -32,18 +52,13 @@ func parseNextFrame(r io.Reader) (frame, error) { case 0x4: return parseSettingsFrame(r, l) case 0x3: // CANCEL_PUSH - fallthrough case 0x5: // PUSH_PROMISE - fallthrough case 0x7: // GOAWAY - fallthrough case 0xd: // MAX_PUSH_ID - fallthrough - default: - // skip over unknown frames - if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil { - return nil, err - } + } + // skip over unknown frames + if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil { + return nil, err } } } diff --git a/http3/frames_test.go b/http3/frames_test.go index 014b6a4a..40ca3c12 100644 --- a/http3/frames_test.go +++ b/http3/frames_test.go @@ -24,7 +24,7 @@ var _ = Describe("Frames", func() { data = append(data, make([]byte, 0x42)...) buf := bytes.NewBuffer(data) (&dataFrame{Length: 0x1234}).Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234))) @@ -34,7 +34,7 @@ var _ = Describe("Frames", func() { It("parses", func() { data := appendVarInt(nil, 0) // type byte data = appendVarInt(data, 0x1337) - frame, err := parseNextFrame(bytes.NewReader(data)) + frame, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337))) @@ -43,7 +43,7 @@ var _ = Describe("Frames", func() { It("writes", func() { buf := &bytes.Buffer{} (&dataFrame{Length: 0xdeadbeef}).Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) @@ -55,7 +55,7 @@ var _ = Describe("Frames", func() { It("parses", func() { data := appendVarInt(nil, 1) // type byte data = appendVarInt(data, 0x1337) - frame, err := parseNextFrame(bytes.NewReader(data)) + frame, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337))) @@ -64,7 +64,7 @@ var _ = Describe("Frames", func() { It("writes", func() { buf := &bytes.Buffer{} (&headersFrame{Length: 0xdeadbeef}).Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) @@ -81,7 +81,7 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - frame, err := parseNextFrame(bytes.NewReader(data)) + frame, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{})) sf := frame.(*settingsFrame) @@ -97,7 +97,7 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data)) + _, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).To(MatchError("duplicate setting: 13")) }) @@ -109,7 +109,7 @@ var _ = Describe("Frames", func() { }} buf := &bytes.Buffer{} sf.Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(sf)) }) @@ -123,13 +123,13 @@ var _ = Describe("Frames", func() { sf.Write(buf) data := buf.Bytes() - _, err := parseNextFrame(bytes.NewReader(data)) + _, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) for i := range data { b := make([]byte, i) copy(b, data[:i]) - _, err := parseNextFrame(bytes.NewReader(b)) + _, err := parseNextFrame(bytes.NewReader(b), nil) Expect(err).To(MatchError(io.EOF)) } }) @@ -141,7 +141,7 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - f, err := parseNextFrame(bytes.NewReader(data)) + f, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeAssignableToTypeOf(&settingsFrame{})) sf := f.(*settingsFrame) @@ -156,7 +156,7 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data)) + _, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingDatagram))) }) @@ -166,7 +166,7 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data)) + _, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).To(MatchError("invalid value for H3_DATAGRAM: 1337")) }) @@ -174,10 +174,55 @@ var _ = Describe("Frames", func() { sf := &settingsFrame{Datagram: true} buf := &bytes.Buffer{} sf.Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(sf)) }) }) }) + + Context("hijacking", func() { + It("reads a frame without hijacking the stream", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, 1337) + customFrameContents := []byte("foobar") + buf.Write(customFrameContents) + + var called bool + _, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) { + Expect(ft).To(BeEquivalentTo(1337)) + called = true + b := make([]byte, 3) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal("foo")) + return true, nil + }) + Expect(err).To(MatchError(errHijacked)) + Expect(called).To(BeTrue()) + }) + + It("reads a frame without hijacking the stream", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, 1337) + customFrameContents := []byte("custom frame") + buf.Write(customFrameContents) + (&dataFrame{Length: 6}).Write(buf) + buf.WriteString("foobar") + + var called bool + frame, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) { + Expect(ft).To(BeEquivalentTo(1337)) + called = true + b := make([]byte, len(customFrameContents)) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal(string(customFrameContents))) + return false, nil + }) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(&dataFrame{Length: 6})) + Expect(called).To(BeTrue()) + }) + }) }) diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index 95adc951..9a1e718e 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -30,7 +30,7 @@ var _ = Describe("Request Writer", func() { ) decode := func(str io.Reader) map[string]string { - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) @@ -85,7 +85,7 @@ var _ = Describe("Request Writer", func() { Expect(err).ToNot(HaveOccurred()) Expect(contentLength).To(BeNumerically(">", 0)) - frame, err := parseNextFrame(strBuf) + frame, err := parseNextFrame(strBuf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) @@ -102,7 +102,7 @@ var _ = Describe("Request Writer", func() { headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) - frame, err := parseNextFrame(strBuf) + frame, err := parseNextFrame(strBuf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index fb2ff186..f1f454cc 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -33,7 +33,7 @@ var _ = Describe("Response Writer", func() { fields := make(map[string][]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) @@ -49,7 +49,7 @@ var _ = Describe("Response Writer", func() { } getData := func(str io.Reader) []byte { - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) df := frame.(*dataFrame) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index f9c0c394..02914f9e 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -51,6 +51,13 @@ type RoundTripper struct { // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. AdditionalSettings map[uint64]uint64 + // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. + // It is called right after parsing the frame type. + // Callers can either process the frame and return control of the stream back to HTTP/3 + // (by returning hijacked false). + // Alternatively, callers can take over the QUIC stream (by returning hijacked true). + StreamHijacker func(FrameType, quic.Stream) (hijacked bool, err error) + // Dial specifies an optional dial function for creating QUIC // connections for requests. // If Dial is nil, quic.DialAddrEarlyContext will be used. @@ -146,6 +153,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr EnableDatagram: r.EnableDatagrams, DisableCompression: r.DisableCompression, MaxHeaderBytes: r.MaxResponseHeaderBytes, + StreamHijacker: r.StreamHijacker, }, r.QuicConfig, r.Dial, diff --git a/http3/server.go b/http3/server.go index 1873f895..69e5ee4c 100644 --- a/http3/server.go +++ b/http3/server.go @@ -144,6 +144,13 @@ type Server struct { // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. AdditionalSettings map[uint64]uint64 + // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. + // It is called right after parsing the frame type. + // Callers can either process the frame and return control of the stream back to HTTP/3 + // (by returning hijacked false). + // Alternatively, callers can take over the QUIC stream (by returning hijacked true). + StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + mutex sync.RWMutex listeners map[*quic.EarlyListener]listenerInfo @@ -186,7 +193,7 @@ func (s *Server) Serve(conn net.PacketConn) error { return s.serveConn(s.TLSConfig, conn) } -// Serve an existing QUIC listener. +// ServeListener serves an existing QUIC listener. // Make sure you use http3.ConfigureTLSConfig to configure a tls.Config // and use it to construct a http3-friendly QUIC listener. // Closing the server does close the listener. @@ -367,6 +374,9 @@ func (s *Server) handleConn(conn quic.EarlyConnection) { rerr := s.handleRequest(conn, str, decoder, func() { conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + if rerr.err == errHijacked { + return + } if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { s.logger.Debugf("Handling request failed: %s", err) if rerr.streamErr != 0 { @@ -414,7 +424,7 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) { str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } - f, err := parseNextFrame(str) + f, err := parseNextFrame(str, nil) if err != nil { conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") return @@ -445,8 +455,17 @@ func (s *Server) maxHeaderBytes() uint64 { } func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { - frame, err := parseNextFrame(str) + var ufh unknownFrameHandlerFunc + if s.StreamHijacker != nil { + ufh = func(ft FrameType) (processed bool, err error) { + return s.StreamHijacker(ft, conn, str) + } + } + frame, err := parseNextFrame(str, ufh) if err != nil { + if err == errHijacked { + return requestError{err: errHijacked} + } return newStreamError(errorRequestIncomplete, err) } hf, ok := frame.(*headersFrame) diff --git a/http3/server_test.go b/http3/server_test.go index 2c722e65..fd6091df 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -117,7 +117,7 @@ var _ = Describe("Server", func() { fields := make(map[string][]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame)