http3: refactor header field processing into a separate function (#3971)

This commit is contained in:
Marten Seemann 2023-07-18 21:16:50 -07:00 committed by GitHub
parent 514df55288
commit ad16aa765d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 56 deletions

View file

@ -13,118 +13,147 @@ import (
"github.com/quic-go/qpack"
)
func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) {
var path, authority, method, protocol, scheme, contentLengthStr string
type header struct {
// Pseudo header fields defined in RFC 9114
Path string
Method string
Authority string
Scheme string
Status string
// for Extended connect
Protocol string
// parsed and deduplicated
ContentLength int64
// all non-pseudo headers
Headers http.Header
}
httpHeaders := http.Header{}
func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) {
hdr := header{Headers: make(http.Header, len(headers))}
var readFirstRegularHeader bool
var contentLengthStr string
for _, h := range headers {
// field names need to be lowercase, see section 4.2 of RFC 9114
if strings.ToLower(h.Name) != h.Name {
return nil, fmt.Errorf("header field is not lower-case: %s", h.Name)
return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name)
}
if !httpguts.ValidHeaderFieldValue(h.Value) {
return nil, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value)
return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value)
}
if h.IsPseudo() {
if readFirstRegularHeader {
// all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114
return nil, fmt.Errorf("received pseudo header %s after a regular header field", h.Name)
return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name)
}
var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses
switch h.Name {
case ":path":
hdr.Path = h.Value
case ":method":
hdr.Method = h.Value
case ":authority":
hdr.Authority = h.Value
case ":protocol":
hdr.Protocol = h.Value
case ":scheme":
hdr.Scheme = h.Value
case ":status":
hdr.Status = h.Value
isResponsePseudoHeader = true
}
if isRequest && isResponsePseudoHeader {
return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name)
}
if !isRequest && !isResponsePseudoHeader {
return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name)
}
} else {
if !httpguts.ValidHeaderFieldName(h.Name) {
return nil, fmt.Errorf("invalid header field name: %q", h.Name)
return header{}, fmt.Errorf("invalid header field name: %q", h.Name)
}
readFirstRegularHeader = true
}
switch h.Name {
case ":path":
path = h.Value
case ":method":
method = h.Value
case ":authority":
authority = h.Value
case ":protocol":
protocol = h.Value
case ":scheme":
scheme = h.Value
case "content-length":
contentLengthStr = h.Value
default:
if !h.IsPseudo() {
httpHeaders.Add(h.Name, h.Value)
switch h.Name {
case "content-length":
contentLengthStr = h.Value
default:
hdr.Headers.Add(h.Name, h.Value)
}
}
}
if len(contentLengthStr) > 0 {
// use ParseUint instead of ParseInt, so that parsing fails on negative values
cl, err := strconv.ParseUint(contentLengthStr, 10, 63)
if err != nil {
return header{}, fmt.Errorf("invalid content length: %w", err)
}
hdr.Headers.Set("Content-Length", contentLengthStr)
hdr.ContentLength = int64(cl)
}
return hdr, nil
}
func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) {
hdr, err := parseHeaders(headerFields, true)
if err != nil {
return nil, err
}
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(httpHeaders["Cookie"]) > 0 {
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
if len(hdr.Headers["Cookie"]) > 0 {
hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; "))
}
isConnect := method == http.MethodConnect
isConnect := hdr.Method == http.MethodConnect
// Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4
isExtendedConnected := isConnect && protocol != ""
isExtendedConnected := isConnect && hdr.Protocol != ""
if isExtendedConnected {
if scheme == "" || path == "" || authority == "" {
if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" {
return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty")
}
} else if isConnect {
if path != "" || authority == "" { // normal CONNECT
if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT
return nil, errors.New(":path must be empty and :authority must not be empty")
}
} else if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
} else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty")
}
var u *url.URL
var requestURI string
var err error
var protocol string
if isConnect {
u = &url.URL{}
if isExtendedConnected {
u, err = url.ParseRequestURI(path)
u, err = url.ParseRequestURI(hdr.Path)
if err != nil {
return nil, err
}
} else {
u.Path = path
u.Path = hdr.Path
}
u.Scheme = scheme
u.Host = authority
requestURI = authority
u.Scheme = hdr.Scheme
u.Host = hdr.Authority
requestURI = hdr.Authority
protocol = hdr.Protocol
} else {
protocol = "HTTP/3.0"
u, err = url.ParseRequestURI(path)
if err != nil {
return nil, err
}
requestURI = path
}
var contentLength int64
if len(contentLengthStr) > 0 {
// use ParseUint instead of ParseInt, so that parsing fails on negative values
cl, err := strconv.ParseUint(contentLengthStr, 10, 63)
u, err = url.ParseRequestURI(hdr.Path)
if err != nil {
return nil, fmt.Errorf("invalid content length: %w", err)
}
httpHeaders.Set("Content-Length", contentLengthStr)
contentLength = int64(cl)
requestURI = hdr.Path
}
return &http.Request{
Method: method,
Method: hdr.Method,
URL: u,
Proto: protocol,
ProtoMajor: 3,
ProtoMinor: 0,
Header: httpHeaders,
Header: hdr.Headers,
Body: nil,
ContentLength: contentLength,
Host: authority,
ContentLength: hdr.ContentLength,
Host: hdr.Authority,
RequestURI: requestURI,
}, nil
}

View file

@ -10,7 +10,7 @@ import (
)
var _ = Describe("Request", func() {
It("populates request", func() {
It("populates requests", func() {
headers := []qpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
@ -89,6 +89,17 @@ var _ = Describe("Request", func() {
Expect(err.Error()).To(ContainSubstring("invalid content length"))
})
It("rejects pseudo header fields defined for responses", func() {
headers := []qpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: ":status", Value: "404"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError("invalid request pseudo header: :status"))
})
It("parses path with leading double slashes", func() {
headers := []qpack.HeaderField{
{Name: ":path", Value: "//foo"},