diff --git a/http3/request.go b/http3/request.go index f82feb8f..9fb369c7 100644 --- a/http3/request.go +++ b/http3/request.go @@ -13,118 +13,147 @@ import ( "github.com/quic-go/qpack" ) -func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { - var path, authority, method, protocol, scheme, contentLengthStr string +type header struct { + // Pseudo header fields defined in RFC 9114 + Path string + Method string + Authority string + Scheme string + Status string + // for Extended connect + Protocol string + // parsed and deduplicated + ContentLength int64 + // all non-pseudo headers + Headers http.Header +} - httpHeaders := http.Header{} +func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) { + hdr := header{Headers: make(http.Header, len(headers))} var readFirstRegularHeader bool + var contentLengthStr string for _, h := range headers { // field names need to be lowercase, see section 4.2 of RFC 9114 if strings.ToLower(h.Name) != h.Name { - return nil, fmt.Errorf("header field is not lower-case: %s", h.Name) + return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name) } if !httpguts.ValidHeaderFieldValue(h.Value) { - return nil, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value) + return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value) } if h.IsPseudo() { if readFirstRegularHeader { // all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114 - return nil, fmt.Errorf("received pseudo header %s after a regular header field", h.Name) + return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name) + } + var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses + switch h.Name { + case ":path": + hdr.Path = h.Value + case ":method": + hdr.Method = h.Value + case ":authority": + hdr.Authority = h.Value + case ":protocol": + hdr.Protocol = h.Value + case ":scheme": + hdr.Scheme = h.Value + case ":status": + hdr.Status = h.Value + isResponsePseudoHeader = true + } + if isRequest && isResponsePseudoHeader { + return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name) + } + if !isRequest && !isResponsePseudoHeader { + return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name) } } else { if !httpguts.ValidHeaderFieldName(h.Name) { - return nil, fmt.Errorf("invalid header field name: %q", h.Name) + return header{}, fmt.Errorf("invalid header field name: %q", h.Name) } readFirstRegularHeader = true - } - - switch h.Name { - case ":path": - path = h.Value - case ":method": - method = h.Value - case ":authority": - authority = h.Value - case ":protocol": - protocol = h.Value - case ":scheme": - scheme = h.Value - case "content-length": - contentLengthStr = h.Value - default: - if !h.IsPseudo() { - httpHeaders.Add(h.Name, h.Value) + switch h.Name { + case "content-length": + contentLengthStr = h.Value + default: + hdr.Headers.Add(h.Name, h.Value) } } } + if len(contentLengthStr) > 0 { + // use ParseUint instead of ParseInt, so that parsing fails on negative values + cl, err := strconv.ParseUint(contentLengthStr, 10, 63) + if err != nil { + return header{}, fmt.Errorf("invalid content length: %w", err) + } + hdr.Headers.Set("Content-Length", contentLengthStr) + hdr.ContentLength = int64(cl) + } + return hdr, nil +} +func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) { + hdr, err := parseHeaders(headerFields, true) + if err != nil { + return nil, err + } // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 - if len(httpHeaders["Cookie"]) > 0 { - httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; ")) + if len(hdr.Headers["Cookie"]) > 0 { + hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; ")) } - isConnect := method == http.MethodConnect + isConnect := hdr.Method == http.MethodConnect // Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4 - isExtendedConnected := isConnect && protocol != "" + isExtendedConnected := isConnect && hdr.Protocol != "" if isExtendedConnected { - if scheme == "" || path == "" || authority == "" { + if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" { return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty") } } else if isConnect { - if path != "" || authority == "" { // normal CONNECT + if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT return nil, errors.New(":path must be empty and :authority must not be empty") } - } else if len(path) == 0 || len(authority) == 0 || len(method) == 0 { + } else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 { return nil, errors.New(":path, :authority and :method must not be empty") } var u *url.URL var requestURI string - var err error + var protocol string if isConnect { u = &url.URL{} if isExtendedConnected { - u, err = url.ParseRequestURI(path) + u, err = url.ParseRequestURI(hdr.Path) if err != nil { return nil, err } } else { - u.Path = path + u.Path = hdr.Path } - u.Scheme = scheme - u.Host = authority - requestURI = authority + u.Scheme = hdr.Scheme + u.Host = hdr.Authority + requestURI = hdr.Authority + protocol = hdr.Protocol } else { protocol = "HTTP/3.0" - u, err = url.ParseRequestURI(path) - if err != nil { - return nil, err - } - requestURI = path - } - - var contentLength int64 - if len(contentLengthStr) > 0 { - // use ParseUint instead of ParseInt, so that parsing fails on negative values - cl, err := strconv.ParseUint(contentLengthStr, 10, 63) + u, err = url.ParseRequestURI(hdr.Path) if err != nil { return nil, fmt.Errorf("invalid content length: %w", err) } - httpHeaders.Set("Content-Length", contentLengthStr) - contentLength = int64(cl) + requestURI = hdr.Path } return &http.Request{ - Method: method, + Method: hdr.Method, URL: u, Proto: protocol, ProtoMajor: 3, ProtoMinor: 0, - Header: httpHeaders, + Header: hdr.Headers, Body: nil, - ContentLength: contentLength, - Host: authority, + ContentLength: hdr.ContentLength, + Host: hdr.Authority, RequestURI: requestURI, }, nil } diff --git a/http3/request_test.go b/http3/request_test.go index 87275c9c..f6a1a424 100644 --- a/http3/request_test.go +++ b/http3/request_test.go @@ -10,7 +10,7 @@ import ( ) var _ = Describe("Request", func() { - It("populates request", func() { + It("populates requests", func() { headers := []qpack.HeaderField{ {Name: ":path", Value: "/foo"}, {Name: ":authority", Value: "quic.clemente.io"}, @@ -89,6 +89,17 @@ var _ = Describe("Request", func() { Expect(err.Error()).To(ContainSubstring("invalid content length")) }) + It("rejects pseudo header fields defined for responses", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: "GET"}, + {Name: ":status", Value: "404"}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError("invalid request pseudo header: :status")) + }) + It("parses path with leading double slashes", func() { headers := []qpack.HeaderField{ {Name: ":path", Value: "//foo"},