diff --git a/http3/client.go b/http3/client.go index b2b7bbdc..849cc5d9 100644 --- a/http3/client.go +++ b/http3/client.go @@ -419,27 +419,13 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui return nil, newConnError(ErrCodeGeneralProtocolError, err) } + res, err := responseFromHeaders(hfs) + if err != nil { + return nil, newStreamError(ErrCodeMessageError, err) + } connState := conn.ConnectionState().TLS - res := &http.Response{ - Proto: "HTTP/3.0", - ProtoMajor: 3, - Header: http.Header{}, - TLS: &connState, - Request: req, - } - for _, hf := range hfs { - switch hf.Name { - case ":status": - status, err := strconv.Atoi(hf.Value) - if err != nil { - return nil, newStreamError(ErrCodeGeneralProtocolError, errors.New("malformed non-numeric status pseudo header")) - } - res.StatusCode = status - res.Status = hf.Value + " " + http.StatusText(status) - default: - res.Header.Add(hf.Name, hf.Value) - } - } + res.TLS = &connState + res.Request = req respBody := newResponseBody(hstr, conn, reqDone) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. diff --git a/http3/client_test.go b/http3/client_test.go index 1c5ac6f8..babcb064 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -637,18 +637,6 @@ var _ = Describe("Client", func() { ) testDone := make(chan struct{}) - getHeadersFrame := func(headers map[string]string) []byte { - headerBuf := &bytes.Buffer{} - enc := qpack.NewEncoder(headerBuf) - for name, value := range headers { - Expect(enc.WriteField(qpack.HeaderField{Name: name, Value: value})).To(Succeed()) - } - Expect(enc.Close()).To(Succeed()) - b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) - b = append(b, headerBuf.Bytes()...) - return b - } - decodeHeader := func(str io.Reader) map[string]string { fields := make(map[string]string) decoder := qpack.NewDecoder(nil) @@ -849,26 +837,6 @@ var _ = Describe("Client", func() { Eventually(closed).Should(BeClosed()) }) - It("sets the Content-Length", func() { - done := make(chan struct{}) - b := getHeadersFrame(map[string]string{ - ":status": "200", - "Content-Length": "1337", - }) - b = (&dataFrame{Length: 0x6}).Append(b) - b = append(b, []byte("foobar")...) - r := bytes.NewReader(b) - str.EXPECT().Close().Do(func() { close(done) }) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) - str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors - // the response body is sent asynchronously, while already reading the response - str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() - req, err := cl.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).ToNot(HaveOccurred()) - Expect(req.ContentLength).To(BeEquivalentTo(1337)) - Eventually(done).Should(BeClosed()) - }) - It("closes the connection when the first frame is not a HEADERS frame", func() { b := (&dataFrame{Length: 0x42}).Append(nil) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) @@ -881,6 +849,24 @@ var _ = Describe("Client", func() { Eventually(closed).Should(BeClosed()) }) + It("cancels the stream when parsing the headers fails", func() { + headerBuf := &bytes.Buffer{} + enc := qpack.NewEncoder(headerBuf) + Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header + Expect(enc.Close()).To(Succeed()) + b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) + b = append(b, headerBuf.Bytes()...) + + r := bytes.NewReader(b) + str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + closed := make(chan struct{}) + str.EXPECT().Close().Do(func() { close(closed) }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(HaveOccurred()) + Eventually(closed).Should(BeClosed()) + }) + It("cancels the stream when the HEADERS frame is too large", func() { b := (&headersFrame{Length: 1338}).Append(nil) r := bytes.NewReader(b) diff --git a/http3/request.go b/http3/headers.go similarity index 89% rename from http3/request.go rename to http3/headers.go index 9fb369c7..edad2c80 100644 --- a/http3/request.go +++ b/http3/headers.go @@ -164,3 +164,23 @@ func hostnameFromRequest(req *http.Request) string { } return "" } + +func responseFromHeaders(headerFields []qpack.HeaderField) (*http.Response, error) { + hdr, err := parseHeaders(headerFields, false) + if err != nil { + return nil, err + } + rsp := &http.Response{ + Proto: "HTTP/3.0", + ProtoMajor: 3, + Header: hdr.Headers, + ContentLength: hdr.ContentLength, + } + status, err := strconv.Atoi(hdr.Status) + if err != nil { + return nil, fmt.Errorf("invalid status code: %w", err) + } + rsp.StatusCode = status + rsp.Status = hdr.Status + " " + http.StatusText(status) + return rsp, nil +} diff --git a/http3/request_test.go b/http3/headers_test.go similarity index 84% rename from http3/request_test.go rename to http3/headers_test.go index f6a1a424..ac09e932 100644 --- a/http3/request_test.go +++ b/http3/headers_test.go @@ -262,3 +262,51 @@ var _ = Describe("Request", func() { }) }) }) + +var _ = Describe("Response", func() { + It("populates responses", func() { + headers := []qpack.HeaderField{ + {Name: ":status", Value: "200"}, + {Name: "content-length", Value: "42"}, + } + rsp, err := responseFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(rsp.Proto).To(Equal("HTTP/3.0")) + Expect(rsp.ProtoMajor).To(Equal(3)) + Expect(rsp.ProtoMinor).To(BeZero()) + Expect(rsp.ContentLength).To(Equal(int64(42))) + Expect(rsp.Header).To(HaveLen(1)) + Expect(rsp.Header.Get("Content-Length")).To(Equal("42")) + Expect(rsp.Body).To(BeNil()) + Expect(rsp.StatusCode).To(BeEquivalentTo(200)) + Expect(rsp.Status).To(Equal("200 OK")) + }) + + It("rejects pseudo header fields after regular header fields", func() { + headers := []qpack.HeaderField{ + {Name: "content-length", Value: "42"}, + {Name: ":status", Value: "200"}, + } + _, err := responseFromHeaders(headers) + Expect(err).To(MatchError("received pseudo header :status after a regular header field")) + }) + + It("rejects invalid status codes", func() { + headers := []qpack.HeaderField{ + {Name: ":status", Value: "foobar"}, + {Name: "content-length", Value: "42"}, + } + _, err := responseFromHeaders(headers) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("invalid status code")) + }) + + It("rejects pseudo header fields defined for requests", func() { + headers := []qpack.HeaderField{ + {Name: ":status", Value: "404"}, + {Name: ":method", Value: "GET"}, + } + _, err := responseFromHeaders(headers) + Expect(err).To(MatchError("invalid response pseudo header: :method")) + }) +})