implement HTTP/3

This commit is contained in:
Marten Seemann 2019-03-11 15:06:23 +09:00
parent 1325909ab7
commit 4f6d0e651a
43 changed files with 2511 additions and 2540 deletions

View file

@ -1,6 +1,6 @@
run: run:
skip-files: skip-files:
- h2quic/response_writer_closenotifier.go - http3/response_writer_closenotifier.go
- internal/handshake/unsafe_test.go - internal/handshake/unsafe_test.go
linters-settings: linters-settings:

View file

@ -34,11 +34,7 @@ Running tests:
go test ./... go test ./...
### HTTP mapping ### QUIC without HTTP/3
We're currently not implementing the HTTP mapping as described in the [QUIC over HTTP draft](https://quicwg.org/base-drafts/draft-ietf-quic-http.html). The HTTP mapping here is a leftover from Google QUIC.
### QUIC without HTTP/2
Take a look at [this echo example](example/echo/echo.go). Take a look at [this echo example](example/echo/echo.go).
@ -50,16 +46,16 @@ See the [example server](example/main.go). Starting a QUIC server is very simila
```go ```go
http.Handle("/", http.FileServer(http.Dir(wwwDir))) http.Handle("/", http.FileServer(http.Dir(wwwDir)))
h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil) http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
``` ```
### As a client ### As a client
See the [example client](example/client/main.go). Use a `h2quic.RoundTripper` as a `Transport` in a `http.Client`. See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`.
```go ```go
http.Client{ http.Client{
Transport: &h2quic.RoundTripper{}, Transport: &http3.RoundTripper{},
} }
``` ```

View file

@ -5,8 +5,6 @@ coverage:
- streams_map_incoming_uni.go - streams_map_incoming_uni.go
- streams_map_outgoing_bidi.go - streams_map_outgoing_bidi.go
- streams_map_outgoing_uni.go - streams_map_outgoing_uni.go
- h2quic/gzipreader.go
- h2quic/response.go
- internal/ackhandler/packet_linkedlist.go - internal/ackhandler/packet_linkedlist.go
- internal/utils/byteinterval_linkedlist.go - internal/utils/byteinterval_linkedlist.go
- internal/utils/packetinterval_linkedlist.go - internal/utils/packetinterval_linkedlist.go

View file

@ -8,7 +8,7 @@ import (
"net/http" "net/http"
"sync" "sync"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
@ -28,7 +28,7 @@ func main() {
} }
logger.SetLogTimeFormat("") logger.SetLogTimeFormat("")
roundTripper := &h2quic.RoundTripper{ roundTripper := &http3.RoundTripper{
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
RootCAs: testdata.GetRootCA(), RootCAs: testdata.GetRootCA(),
}, },

View file

@ -15,7 +15,7 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
@ -135,9 +135,9 @@ func main() {
var err error var err error
if *tcp { if *tcp {
certFile, keyFile := testdata.GetCertificatePaths() certFile, keyFile := testdata.GetCertificatePaths()
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil) err = http3.ListenAndServe(bCap, certFile, keyFile, nil)
} else { } else {
server := h2quic.Server{ server := http3.Server{
Server: &http.Server{Addr: bCap}, Server: &http.Server{Addr: bCap},
} }
err = server.ListenAndServeTLS(testdata.GetCertificatePaths()) err = server.ListenAndServeTLS(testdata.GetCertificatePaths())

3
go.mod
View file

@ -5,9 +5,10 @@ go 1.12
require ( require (
github.com/cheekybits/genny v1.0.0 github.com/cheekybits/genny v1.0.0
github.com/golang/mock v1.2.0 github.com/golang/mock v1.2.0
github.com/marten-seemann/qpack v0.1.0
github.com/marten-seemann/qtls v0.2.3 github.com/marten-seemann/qtls v0.2.3
github.com/onsi/ginkgo v1.7.0 github.com/onsi/ginkgo v1.7.0
github.com/onsi/gomega v1.4.3 github.com/onsi/gomega v1.4.3
golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7
) )

4
go.sum
View file

@ -8,6 +8,8 @@ github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/marten-seemann/qpack v0.1.0 h1:/0M7lkda/6mus9B8u34Asqm8ZhHAAt9Ho0vniNuVSVg=
github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI=
github.com/marten-seemann/qtls v0.2.3 h1:0yWJ43C62LsZt08vuQJDK1uC1czUc3FJeCLPoNAI4vA= github.com/marten-seemann/qtls v0.2.3 h1:0yWJ43C62LsZt08vuQJDK1uC1czUc3FJeCLPoNAI4vA=
github.com/marten-seemann/qtls v0.2.3/go.mod h1:xzjG7avBwGGbdZ8dTGxlBnLArsVKLvwmjgmPuiQEcYk= github.com/marten-seemann/qtls v0.2.3/go.mod h1:xzjG7avBwGGbdZ8dTGxlBnLArsVKLvwmjgmPuiQEcYk=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@ -19,6 +21,8 @@ golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 h1:jsG6UpNLt9iAsb0S2AGW28
golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7 h1:Qe/u+eY379X4He4GBMFZYu3pmh1ML5yT1aL1ndNM1zQ=
golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

View file

@ -1,311 +0,0 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type roundTripperOpts struct {
DisableCompression bool
}
var dialAddr = quic.DialAddr
// client is a HTTP2 client doing QUIC requests
type client struct {
mutex sync.RWMutex
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
hostname string
handshakeErr error
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
session quic.Session
headerStream quic.Stream
headerErr *qerr.QuicError
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
logger utils.Logger
}
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{KeepAlive: true}
// newClient creates a new client
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
logger: utils.DefaultLogger.WithPrefix("client"),
}
}
// dial dials the connection
func (c *client) dial() error {
var err error
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
// once the version has been negotiated, open the header stream
c.headerStream, err = c.session.OpenStreamSync()
if err != nil {
return err
}
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
go c.handleHeaderStream()
return nil
}
func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var err error
for err == nil {
err = c.readResponse(h2framer, decoder)
}
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.NoError {
c.logger.Debugf("Error handling header stream: %s", err)
}
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
// stop all running request
close(c.headerErrored)
}
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
frame, err := h2framer.ReadFrame()
if err != nil {
return err
}
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
return errors.New("not a headers frame")
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
return fmt.Errorf("cannot read header fields: %s", err.Error())
}
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
return err
}
responseChan <- rsp
return nil
}
// Roundtrip executes a request and returns a response
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme")
}
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
}
c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
}
hasBody := (req.Body != nil)
responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync()
if err != nil {
_ = c.closeWithError(err)
return nil, err
}
c.mutex.Lock()
c.responses[dataStream.StreamID()] = responseChan
c.mutex.Unlock()
var requestedGzip bool
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true
}
// TODO: add support for trailers
endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil {
_ = c.closeWithError(err)
return nil, err
}
resc := make(chan error, 1)
if hasBody {
go func() {
resc <- c.writeRequestBody(dataStream, req.Body)
}()
}
var res *http.Response
var receivedResponse bool
var bodySent bool
if !hasBody {
bodySent = true
}
ctx := req.Context()
for !(bodySent && receivedResponse) {
select {
case res = <-responseChan:
receivedResponse = true
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
case err := <-resc:
bodySent = true
if err != nil {
return nil, err
}
case <-ctx.Done():
// error code 6 signals that stream was canceled
dataStream.CancelRead(6)
dataStream.CancelWrite(6)
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
return nil, ctx.Err()
case <-c.headerErrored:
// an error occurred on the header stream
_ = c.closeWithError(c.headerErr)
return nil, c.headerErr
}
}
// TODO: correctly set this variable
var streamEnded bool
isHead := (req.Method == "HEAD")
res = setLength(res, isHead, streamEnded)
if streamEnded || isHead {
res.Body = noBody
} else {
res.Body = &responseBody{dataStream}
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &gzipReader{body: res.Body}
res.Uncompressed = true
}
}
res.Request = req
return res, nil
}
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
defer func() {
cerr := body.Close()
if err == nil {
// TODO: what to do with dataStream here? Maybe reset it?
err = cerr
}
}()
_, err = io.Copy(dataStream, body)
if err != nil {
// TODO: what to do with dataStream here? Maybe reset it?
return err
}
return dataStream.Close()
}
func (c *client) closeWithError(e error) error {
if c.session == nil {
return nil
}
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
}
// Close closes the client
func (c *client) Close() error {
if c.session == nil {
return nil
}
return c.session.Close()
}
// copied from net/transport.go
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}

View file

@ -1,640 +0,0 @@
package h2quic
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
"net/http"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Client", func() {
var (
client *client
session *mockSession
headerStream *mockStream
req *http.Request
origDialAddr = dialAddr
)
injectResponse := func(id protocol.StreamID, rsp *http.Response) {
EventuallyWithOffset(0, func() bool {
client.mutex.Lock()
defer client.mutex.Unlock()
_, ok := client.responses[id]
return ok
}).Should(BeTrue())
rspChan := client.responses[5]
ExpectWithOffset(0, rspChan).ToNot(BeClosed())
rspChan <- rsp
}
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal(hostname))
session = newMockSession()
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
client.session = session
headerStream = newMockStream(3)
client.headerStream = headerStream
client.requestWriter = newRequestWriter(headerStream, utils.DefaultLogger)
var err error
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
Expect(client.tlsConf).To(Equal(tlsConf))
})
It("saves the QUIC config", func() {
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
Expect(client.config).To(Equal(quicConf))
})
It("uses the default QUIC config if none is give", func() {
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
Expect(client.config).ToNot(BeNil())
Expect(client.config).To(Equal(defaultQuicConfig))
})
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
It("dials", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
close(headerStream.unblockRead)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
// fmt.Println("done")
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("uses the custom dialer, if provided", func() {
var tlsCfg *tls.Config
var qCfg *quic.Config
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
tlsCfg = tlsCfgP
qCfg = cfg
return session, nil
}
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
Expect(qCfg).To(Equal(client.config))
Expect(tlsCfg).To(Equal(client.tlsConf))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamOpenErr = testErr
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("returns a request when dial fails", func() {
testErr := errors.New("dial error")
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
close(done)
}()
_, err = client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Eventually(done).Should(BeClosed())
})
Context("Doing requests", func() {
var request *http.Request
var dataStream *mockStream
getRequest := func(data []byte) *http2.MetaHeadersFrame {
r := bytes.NewReader(data)
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, r)
frame, err := h2framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
return mhframe
}
getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
fields := make(map[string]string)
for _, hf := range f.Fields {
fields[hf.Name] = hf.Value
}
return fields
}
BeforeEach(func() {
var err error
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
dataStream = newMockStream(5)
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
It("does a request", func() {
teapot := &http.Response{
Status: "418 I'm a teapot",
StatusCode: 418,
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).To(Equal(teapot))
Expect(rsp.Body).To(BeAssignableToTypeOf(&responseBody{}))
Expect(rsp.Body.(*responseBody).Stream).To(Equal(dataStream))
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
Expect(rsp.Request).To(Equal(request))
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
injectResponse(5, teapot)
Expect(client.headerErrored).ToNot(BeClosed())
Eventually(done).Should(BeClosed())
})
It("errors if a request without a body is canceled", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
cancel()
Eventually(done).Should(BeClosed())
Expect(dataStream.canceledRead).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("errors if a request with a body is canceled after the body is sent", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
request.Body = &mockBody{}
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
time.Sleep(10 * time.Millisecond)
cancel()
Eventually(done).Should(BeClosed())
Expect(dataStream.canceledRead).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("errors if a request with a body is canceled before the body is sent", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
request.Body = &mockBody{}
cancel()
time.Sleep(10 * time.Millisecond)
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(dataStream.canceledRead).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("closes the quic client when encountering an error on the header stream", func() {
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(client.headerErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InternalError))
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
})
It("returns subsequent request if there was an error on the header stream before", func() {
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
_, err := client.RoundTrip(request)
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InternalError))
// now that the first request failed due to an error on the header stream, try another request
_, nextErr := client.RoundTrip(request)
Expect(nextErr).To(MatchError(err))
})
It("blocks if no stream is available", func() {
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
session.blockOpenStreamSync = true
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
client.Close()
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
Context("validating the address", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
It("refuses to do plain HTTP requests", func() {
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("quic http2: unsupported scheme"))
})
It("adds the port for request URLs without one", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
// the client.RoundTrip will block, because the encryption level is still set to Unencrypted
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
})
It("sets the EndStream header for requests without a body", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RoundTrip(request)
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
mhf := getRequest(headerStream.dataWritten.Bytes())
Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("sets the EndStream header to false for requests with a body", func() {
request.Body = &mockBody{}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RoundTrip(request)
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
mhf := getRequest(headerStream.dataWritten.Bytes())
Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
Context("requests containing a Body", func() {
var requestBody []byte
var response *http.Response
BeforeEach(func() {
requestBody = []byte("request body")
body := &mockBody{}
body.SetData(requestBody)
request.Body = body
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
// fake a handshake
client.dialOnce.Do(func() {})
session.streamsToOpen = []quic.Stream{dataStream}
})
It("sends a request", func() {
rspChan := make(chan *http.Response)
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
rspChan <- rsp
}()
injectResponse(5, response)
Eventually(rspChan).Should(Receive(Equal(response)))
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
Expect(dataStream.closed).To(BeTrue())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
It("returns the error that occurred when reading the body", func() {
testErr := errors.New("testErr")
request.Body.(*mockBody).readErr = testErr
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
It("returns the error that occurred when closing the body", func() {
testErr := errors.New("testErr")
request.Body.(*mockBody).closeErr = testErr
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
})
Context("gzip compression", func() {
var gzippedData []byte // a gzipped foobar
var response *http.Response
BeforeEach(func() {
var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write([]byte("foobar"))
w.Close()
gzippedData = b.Bytes()
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
})
It("adds the gzip header to requests", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).ToNot(BeNil())
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
Expect(rsp.Header.Get("Content-Length")).To(BeEmpty())
data := make([]byte, 6)
_, err = io.ReadFull(rsp.Body, data)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
close(done)
}()
dataStream.dataToRead.Write(gzippedData)
response.Header.Add("Content-Encoding", "gzip")
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
close(dataStream.unblockRead)
Eventually(done).Should(BeClosed())
})
It("doesn't add gzip if the header disable it", func() {
client.opts.DisableCompression = true
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).ToNot(HaveKey("accept-encoding"))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).ToNot(BeNil())
data := make([]byte, 11)
rsp.Body.Read(data)
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(data).To(Equal([]byte("not gzipped")))
close(done)
}()
dataStream.dataToRead.Write([]byte("not gzipped"))
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
Eventually(done).Should(BeClosed())
})
It("doesn't add the gzip header for requests that have the accept-enconding set", func() {
request.Header.Add("accept-encoding", "gzip")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
data := make([]byte, 12)
_, err = rsp.Body.Read(data)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(data).To(Equal([]byte("gzipped data")))
close(done)
}()
dataStream.dataToRead.Write([]byte("gzipped data"))
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
Eventually(done).Should(BeClosed())
})
})
Context("handling the header stream", func() {
var h2framer *http2.Framer
BeforeEach(func() {
h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
client.responses[23] = make(chan *http.Response)
})
It("reads header values from a response", func() {
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
headerStream.dataToRead.Write(data)
go client.handleHeaderStream()
var rsp *http.Response
Eventually(client.responses[23]).Should(Receive(&rsp))
Expect(rsp).ToNot(BeNil())
Expect(rsp.Proto).To(Equal("HTTP/2.0"))
Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
Expect(rsp.StatusCode).To(BeEquivalentTo(302))
Expect(rsp.Status).To(Equal("302 Found"))
Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
})
It("errors if the H2 frame is not a HeadersFrame", func() {
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InternalError, "not a headers frame")))
})
It("errors if it can't read the HPACK encoded header fields", func() {
h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 23,
EndHeaders: true,
BlockFragment: []byte("invalid HPACK data"),
})
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InternalError))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
})
It("errors if the stream cannot be found", func() {
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1337,
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
Expect(err).ToNot(HaveOccurred())
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InternalError))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
})
})
})
})

View file

@ -1,35 +0,0 @@
package h2quic
// copied from net/transport.go
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
)
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}

View file

@ -1,13 +0,0 @@
package h2quic
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestH2quic(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "H2quic Suite")
}

View file

@ -1,29 +0,0 @@
package h2quic
import (
"io"
quic "github.com/lucas-clemente/quic-go"
)
type requestBody struct {
requestRead bool
dataStream quic.Stream
}
// make sure the requestBody can be used as a http.Request.Body
var _ io.ReadCloser = &requestBody{}
func newRequestBody(stream quic.Stream) *requestBody {
return &requestBody{dataStream: stream}
}
func (b *requestBody) Read(p []byte) (int, error) {
b.requestRead = true
return b.dataStream.Read(p)
}
func (b *requestBody) Close() error {
// stream's Close() closes the write side, not the read side
return nil
}

View file

@ -1,39 +0,0 @@
package h2quic
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request body", func() {
var (
stream *mockStream
rb *requestBody
)
BeforeEach(func() {
stream = &mockStream{}
stream.dataToRead.Write([]byte("foobar")) // provides data to be read
rb = newRequestBody(stream)
})
It("reads from the stream", func() {
b := make([]byte, 10)
n, _ := stream.Read(b)
Expect(n).To(Equal(6))
Expect(b[0:6]).To(Equal([]byte("foobar")))
})
It("saves if the stream was read from", func() {
Expect(rb.requestRead).To(BeFalse())
rb.Read(make([]byte, 1))
Expect(rb.requestRead).To(BeTrue())
})
It("doesn't close the stream when closing the request body", func() {
Expect(stream.closed).To(BeFalse())
err := rb.Close()
Expect(err).ToNot(HaveOccurred())
Expect(stream.closed).To(BeFalse())
})
})

View file

@ -1,121 +0,0 @@
package h2quic
import (
"net/http"
"net/url"
"golang.org/x/net/http2/hpack"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request", func() {
It("populates request", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "content-length", Value: "42"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Method).To(Equal("GET"))
Expect(req.URL.Path).To(Equal("/foo"))
Expect(req.Proto).To(Equal("HTTP/2.0"))
Expect(req.ProtoMajor).To(Equal(2))
Expect(req.ProtoMinor).To(Equal(0))
Expect(req.ContentLength).To(Equal(int64(42)))
Expect(req.Header).To(BeEmpty())
Expect(req.Body).To(BeNil())
Expect(req.Host).To(Equal("quic.clemente.io"))
Expect(req.RequestURI).To(Equal("/foo"))
Expect(req.TLS).ToNot(BeNil())
})
It("concatenates the cookie headers", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "cookie", Value: "cookie1=foobar1"},
{Name: "cookie", Value: "cookie2=foobar2"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Header).To(Equal(http.Header{
"Cookie": []string{"cookie1=foobar1; cookie2=foobar2"},
}))
})
It("handles other headers", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "cache-control", Value: "max-age=0"},
{Name: "duplicate-header", Value: "1"},
{Name: "duplicate-header", Value: "2"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Header).To(Equal(http.Header{
"Cache-Control": []string{"max-age=0"},
"Duplicate-Header": []string{"1", "2"},
}))
})
It("errors with missing path", func() {
headers := []hpack.HeaderField{
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
It("errors with missing method", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
It("errors with missing authority", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":method", Value: "GET"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
Context("extracting the hostname from a request", func() {
var url *url.URL
BeforeEach(func() {
var err error
url, err = url.Parse("https://quic.clemente.io:1337")
Expect(err).ToNot(HaveOccurred())
})
It("uses req.URL.Host", func() {
req := &http.Request{URL: url}
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
})
It("uses req.URL.Host even if req.Host is available", func() {
req := &http.Request{
Host: "www.example.org",
URL: url,
}
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
})
It("returns an empty hostname if nothing is set", func() {
Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty())
})
})
})

View file

@ -1,203 +0,0 @@
package h2quic
import (
"bytes"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type requestWriter struct {
mutex sync.Mutex
headerStream quic.Stream
henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this
logger utils.Logger
}
const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
rw := &requestWriter{
headerStream: headerStream,
logger: logger,
}
rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw
}
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
// TODO: add support for trailers
// TODO: add support for gzip compression
// TODO: write continuation frames, if the header frame is too long
w.mutex.Lock()
defer w.mutex.Unlock()
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
h2framer := http2.NewFramer(w.headerStream, nil)
return h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(dataStreamID),
EndHeaders: true,
EndStream: endStream,
BlockFragment: w.hbuf.Bytes(),
Priority: http2.PriorityParam{Weight: 0xff},
})
}
// the rest of this files is copied from http2.Transport
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
w.hbuf.Reset()
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return nil, err
}
var path string
if req.Method != "CONNECT" {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
}
return nil, fmt.Errorf("invalid request :path %q", orig)
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
}
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
w.writeHeader(":authority", host)
w.writeHeader(":method", req.Method)
if req.Method != "CONNECT" {
w.writeHeader(":path", path)
w.writeHeader(":scheme", req.URL.Scheme)
}
if trailers != "" {
w.writeHeader("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
switch lowKey {
case "host", "content-length":
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
case "user-agent":
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
w.writeHeader(lowKey, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
w.writeHeader("accept-encoding", "gzip")
}
if !didUA {
w.writeHeader("user-agent", defaultUserAgent)
}
return w.hbuf.Bytes(), nil
}
func (w *requestWriter) writeHeader(name, value string) {
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
func actualContentLength(req *http.Request) int64 {
if req.Body == nil {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}

View file

@ -1,114 +0,0 @@
package h2quic
import (
"bytes"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request", func() {
var (
rw *requestWriter
headerStream *mockStream
decoder *hpack.Decoder
)
BeforeEach(func() {
headerStream = &mockStream{}
rw = newRequestWriter(headerStream, utils.DefaultLogger)
decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
})
decode := func(p []byte) (*http2.HeadersFrame, map[string] /* HeaderField.Name */ string /* HeaderField.Value */) {
framer := http2.NewFramer(nil, bytes.NewReader(p))
frame, err := framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
headerFrame := frame.(*http2.HeadersFrame)
fields, err := decoder.DecodeFull(headerFrame.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
values := make(map[string]string)
for _, headerField := range fields {
values[headerField.Name] = headerField.Value
}
return headerFrame, values
}
It("writes a GET request", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, false)
headerFrame, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamID).To(Equal(uint32(1337)))
Expect(headerFrame.HasPriority()).To(BeTrue())
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
})
It("sets the EndStream header", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, false)
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamEnded()).To(BeTrue())
})
It("doesn't set the EndStream header, if requested", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, false, false)
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamEnded()).To(BeFalse())
})
It("requests gzip compression, if requested", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, true)
_, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
It("writes a POST request", func() {
form := url.Values{}
form.Add("foo", "bar")
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode()))
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 5, true, false)
_, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
Expect(headerFields).To(HaveKey("content-length"))
contentLength, err := strconv.Atoi(headerFields["content-length"])
Expect(err).ToNot(HaveOccurred())
Expect(contentLength).To(BeNumerically(">", 0))
})
It("sends cookies", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
cookie1 := &http.Cookie{
Name: "Cookie #1",
Value: "Value #1",
}
cookie2 := &http.Cookie{
Name: "Cookie #2",
Value: "Value #2",
}
req.AddCookie(cookie1)
req.AddCookie(cookie2)
rw.WriteRequest(req, 11, true, false)
_, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
})
})

View file

@ -1,95 +0,0 @@
package h2quic
import (
"bytes"
"errors"
"io/ioutil"
"net/http"
"net/textproto"
"strconv"
"strings"
"golang.org/x/net/http2"
)
// copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
if f.Truncated {
return nil, errResponseHeaderListSize
}
status := f.PseudoValue("status")
if status == "" {
return nil, errors.New("missing status pseudo header")
}
statusCode, err := strconv.Atoi(status)
if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header")
}
// TODO: handle statusCode == 100
header := make(http.Header)
res := &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: header,
StatusCode: statusCode,
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
t = make(http.Header)
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil
})
} else {
header[key] = append(header[key], hf.Value)
}
}
return res, nil
}
// continuation of the handleResponse function
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if !streamEnded || isHead {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64
}
}
}
return res
}
// copied from net/http/server.go
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}

View file

@ -1,163 +0,0 @@
package h2quic
import (
"bytes"
"context"
"io"
"net/http"
"sync"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockStream struct {
id protocol.StreamID
dataToRead bytes.Buffer
dataWritten bytes.Buffer
canceledRead bool
canceledWrite bool
closed bool
remoteClosed bool
unblockRead chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
}
var _ quic.Stream = &mockStream{}
func newMockStream(id protocol.StreamID) *mockStream {
s := &mockStream{
id: id,
unblockRead: make(chan struct{}),
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
}
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
func (s *mockStream) CancelRead(quic.ErrorCode) { s.canceledRead = true }
func (s *mockStream) CancelWrite(quic.ErrorCode) { s.canceledWrite = true }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
func (s mockStream) StreamID() protocol.StreamID { return s.id }
func (s *mockStream) Context() context.Context { return s.ctx }
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) Read(p []byte) (int, error) {
n, _ := s.dataToRead.Read(p)
if n == 0 { // block if there's no data
<-s.unblockRead
return 0, io.EOF
}
return n, nil // never return an EOF
}
func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) }
var _ = Describe("Response Writer", func() {
var (
w *responseWriter
headerStream *mockStream
dataStream *mockStream
)
BeforeEach(func() {
headerStream = &mockStream{}
dataStream = &mockStream{}
w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5, utils.DefaultLogger)
})
decodeHeaderFields := func() map[string][]string {
fields := make(map[string][]string)
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, bytes.NewReader(headerStream.dataWritten.Bytes()))
frame, err := h2framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&http2.HeadersFrame{}))
hframe := frame.(*http2.HeadersFrame)
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
Expect(mhframe.StreamID).To(BeEquivalentTo(5))
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
for _, p := range mhframe.Fields {
fields[p.Name] = append(fields[p.Name], p.Value)
}
return fields
}
It("writes status", func() {
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveLen(1))
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
})
It("writes headers", func() {
w.Header().Add("content-length", "42")
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
})
It("writes multiple headers with the same name", func() {
const cookie1 = "test1=1; Max-Age=7200; path=/"
const cookie2 = "test2=2; Max-Age=7200; path=/"
w.Header().Add("set-cookie", cookie1)
w.Header().Add("set-cookie", cookie2)
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveKey("set-cookie"))
cookies := fields["set-cookie"]
Expect(cookies).To(ContainElement(cookie1))
Expect(cookies).To(ContainElement(cookie2))
})
It("writes data", func() {
n, err := w.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
// Should have written 200 on the header stream
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
// And foobar on the data stream
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
})
It("writes data after WriteHeader is called", func() {
w.WriteHeader(http.StatusTeapot)
n, err := w.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
// Should have written 418 on the header stream
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
// And foobar on the data stream
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
})
It("does not WriteHeader() twice", func() {
w.WriteHeader(200)
w.WriteHeader(500)
fields := decodeHeaderFields()
Expect(fields).To(HaveLen(1))
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
})
It("doesn't allow writes if the status code doesn't allow a body", func() {
w.WriteHeader(304)
n, err := w.Write([]byte("foobar"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(http.ErrBodyNotAllowed))
Expect(dataStream.dataWritten.Bytes()).To(HaveLen(0))
})
})

View file

@ -1,536 +0,0 @@
package h2quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockSession struct {
closed bool
closedWithError error
dataStream quic.Stream
streamToAccept quic.Stream
streamsToOpen []quic.Stream
blockOpenStreamSync bool
blockOpenStreamChan chan struct{} // close this chan (or call Close) to make OpenStreamSync return
streamOpenErr error
ctx context.Context
ctxCancel context.CancelFunc
}
func newMockSession() *mockSession {
return &mockSession{blockOpenStreamChan: make(chan struct{})}
}
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
return s.dataStream, nil
}
func (s *mockSession) AcceptStream() (quic.Stream, error) { return s.streamToAccept, nil }
func (s *mockSession) OpenStream() (quic.Stream, error) {
if s.streamOpenErr != nil {
return nil, s.streamOpenErr
}
str := s.streamsToOpen[0]
s.streamsToOpen = s.streamsToOpen[1:]
return str, nil
}
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
if s.blockOpenStreamSync {
<-s.blockOpenStreamChan
}
return s.OpenStream()
}
func (s *mockSession) Close() error {
s.ctxCancel()
if !s.closed {
close(s.blockOpenStreamChan)
}
s.closed = true
return nil
}
func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error {
s.closedWithError = e
return s.Close()
}
func (s *mockSession) LocalAddr() net.Addr {
panic("not implemented")
}
func (s *mockSession) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
}
func (s *mockSession) Context() context.Context {
return s.ctx
}
func (s *mockSession) ConnectionState() tls.ConnectionState { panic("not implemented") }
func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") }
func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") }
func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") }
var _ = Describe("H2 server", func() {
var (
s *Server
session *mockSession
dataStream *mockStream
origQuicListenAddr = quicListenAddr
)
BeforeEach(func() {
s = &Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
logger: utils.DefaultLogger,
}
dataStream = newMockStream(0)
close(dataStream.unblockRead)
session = newMockSession()
session.dataStream = dataStream
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
origQuicListenAddr = quicListenAddr
})
AfterEach(func() {
quicListenAddr = origQuicListenAddr
})
Context("handling requests", func() {
var (
h2framer *http2.Framer
hpackDecoder *hpack.Decoder
headerStream *mockStream
)
BeforeEach(func() {
headerStream = &mockStream{}
hpackDecoder = hpack.NewDecoder(4096, nil)
h2framer = http2.NewFramer(nil, headerStream)
})
It("handles a sample GET request", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.RemoteAddr).To(Equal("127.0.0.1:42"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeTrue())
Expect(dataStream.canceledRead).To(BeFalse())
})
It("returns 200 with an empty handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() []byte {
return headerStream.dataWritten.Bytes()
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200
})
It("correctly handles a panicking handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("foobar")
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() []byte {
return headerStream.dataWritten.Bytes()
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500
})
It("resets the dataStream when client sends a body in GET request", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeFalse())
})
It("resets the dataStream when the body of POST request is not read", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.Method).To(Equal("POST"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue())
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
Expect(handlerCalled).To(BeTrue())
})
It("handles a request for which the client immediately resets the data stream", func() {
session.dataStream = nil
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
})
It("resets the dataStream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = struct {
io.Reader
io.Closer
}{}
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue())
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
Expect(handlerCalled).To(BeTrue())
})
It("closes the dataStream if the body of POST request was read", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.Method).To(Equal("POST"))
handlerCalled = true
// read the request body
b := make([]byte, 1000)
n, _ := r.Body.Read(b)
Expect(n).ToNot(BeZero())
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
dataStream.dataToRead.Write([]byte("foo=bar"))
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.canceledRead).To(BeFalse())
})
It("ignores PRIORITY frames", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(handlerCalled)
})
buf := &bytes.Buffer{}
framer := http2.NewFramer(buf, nil)
err := framer.WritePriority(10, http2.PriorityParam{Weight: 42})
Expect(err).ToNot(HaveOccurred())
Expect(buf.Bytes()).ToNot(BeEmpty())
headerStream.dataToRead.Write(buf.Bytes())
err = s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).ToNot(HaveOccurred())
Consistently(handlerCalled).ShouldNot(BeClosed())
Expect(dataStream.canceledRead).To(BeFalse())
Expect(dataStream.closed).To(BeFalse())
})
It("errors when non-header frames are received", func() {
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x06, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5,
'f', 'o', 'o', 'b', 'a', 'r',
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: expected a header frame"))
})
It("Cancels the request context when the datstream is closed", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := r.Context().Err()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("context canceled"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
dataStream.Close()
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeTrue())
Expect(dataStream.canceledRead).To(BeFalse())
})
})
It("handles the header stream", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
})
It("closes the connection if it encounters an error on the header stream", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
Eventually(func() bool { return session.closed }).Should(BeTrue())
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.InternalError, "cannot read frame")))
})
It("supports closing after first request", func() {
s.CloseAfterFirstRequest = true
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
Expect(session.closed).To(BeFalse())
go s.handleHeaderStream(session)
Eventually(func() bool { return session.closed }).Should(BeTrue())
})
It("uses the default handler as fallback", func() {
var handlerCalled bool
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
}))
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
})
Context("setting http headers", func() {
var expected http.Header
getExpectedHeader := func(versions []protocol.VersionNumber) http.Header {
var versionsAsString []string
for _, v := range versions {
versionsAsString = append(versionsAsString, v.ToAltSvc())
}
return http.Header{
"Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))},
}
}
BeforeEach(func() {
Expect(getExpectedHeader([]protocol.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}}))
expected = getExpectedHeader(protocol.SupportedVersions)
})
It("sets proper headers with numeric port", func() {
s.Server.Addr = ":443"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("sets proper headers with full addr", func() {
s.Server.Addr = "127.0.0.1:443"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("sets proper headers with string port", func() {
s.Server.Addr = ":https"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("works multiple times", func() {
s.Server.Addr = ":https"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
hdr = http.Header{}
err = s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
})
It("should error when ListenAndServe is called with s.Server nil", func() {
err := (&Server{}).ListenAndServe()
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
err := (&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should nop-Close() when s.server is nil", func() {
err := (&Server{}).Close()
Expect(err).NotTo(HaveOccurred())
})
It("errors when ListenAndServer is called after Close", func() {
serv := &Server{Server: &http.Server{}}
Expect(serv.Close()).To(Succeed())
err := serv.ListenAndServe()
Expect(err).To(MatchError("Server is already closed"))
})
Context("ListenAndServe", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"
})
AfterEach(func() {
Expect(s.Close()).To(Succeed())
})
It("may only be called once", func() {
cErr := make(chan error)
for i := 0; i < 2; i++ {
go func() {
defer GinkgoRecover()
err := s.ListenAndServe()
if err != nil {
cErr <- err
}
}()
}
err := <-cErr
Expect(err).To(MatchError("ListenAndServe may only be called once"))
err = s.Close()
Expect(err).NotTo(HaveOccurred())
}, 0.5)
It("uses the quic.Config to start the quic server", func() {
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
var receivedConf *quic.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
receivedConf = config
return nil, errors.New("listen err")
}
s.QuicConfig = conf
go s.ListenAndServe()
Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf))
})
})
Context("ListenAndServeTLS", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"
})
AfterEach(func() {
err := s.Close()
Expect(err).NotTo(HaveOccurred())
})
It("may only be called once", func() {
cErr := make(chan error)
for i := 0; i < 2; i++ {
go func() {
defer GinkgoRecover()
err := s.ListenAndServeTLS(testdata.GetCertificatePaths())
if err != nil {
cErr <- err
}
}()
}
err := <-cErr
Expect(err).To(MatchError("ListenAndServe may only be called once"))
err = s.Close()
Expect(err).NotTo(HaveOccurred())
}, 0.5)
})
It("closes gracefully", func() {
err := s.CloseGracefully(0)
Expect(err).NotTo(HaveOccurred())
})
It("errors when listening fails", func() {
testErr := errors.New("listen error")
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
return nil, testErr
}
fullpem, privkey := testdata.GetCertificatePaths()
err := ListenAndServeQUIC("", fullpem, privkey, nil)
Expect(err).To(MatchError(testErr))
})
})

68
http3/body.go Normal file
View file

@ -0,0 +1,68 @@
package http3
import (
"errors"
"io"
)
// The body of a http.Request or http.Response.
type body struct {
str io.ReadCloser
isRequest bool
bytesRemainingInFrame uint64
}
var _ io.ReadCloser = &body{}
func newRequestBody(str io.ReadCloser) *body {
return &body{
str: str,
isRequest: true,
}
}
func newResponseBody(str io.ReadCloser) *body {
return &body{str: str}
}
func (r *body) Read(b []byte) (int, error) {
if r.bytesRemainingInFrame == 0 {
parseLoop:
for {
frame, err := parseNextFrame(r.str)
if err != nil {
return 0, err
}
switch f := frame.(type) {
case *headersFrame:
// skip HEADERS frames
continue
case *dataFrame:
r.bytesRemainingInFrame = f.Length
break parseLoop
default:
return 0, errors.New("unexpected frame")
}
}
}
var n int
var err error
if r.bytesRemainingInFrame < uint64(len(b)) {
n, err = r.str.Read(b[:r.bytesRemainingInFrame])
} else {
n, err = r.str.Read(b)
}
r.bytesRemainingInFrame -= uint64(n)
return n, err
}
func (r *body) Close() error {
// quic.Stream.Close() closes the write side, not the read side
if r.isRequest {
return nil
}
return r.str.Close()
}

150
http3/body_test.go Normal file
View file

@ -0,0 +1,150 @@
package http3
import (
"bytes"
"io"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type closingBuffer struct {
*bytes.Buffer
closed bool
}
func (b *closingBuffer) Close() error { b.closed = true; return nil }
type bodyType uint8
const (
bodyTypeRequest bodyType = iota
bodyTypeResponse
)
var _ = Describe("Body", func() {
var rb *body
var buf *bytes.Buffer
getDataFrame := func(data []byte) []byte {
b := &bytes.Buffer{}
(&dataFrame{Length: uint64(len(data))}).Write(b)
b.Write(data)
return b.Bytes()
}
BeforeEach(func() {
buf = &bytes.Buffer{}
})
for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
bodyType := bt
BeforeEach(func() {
cb := &closingBuffer{Buffer: buf}
switch bodyType {
case bodyTypeRequest:
rb = newRequestBody(cb)
case bodyTypeResponse:
rb = newResponseBody(cb)
}
})
It("reads DATA frames in a single run", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 6)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
It("reads DATA frames in multiple runs", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 3)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b).To(Equal([]byte("foo")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b).To(Equal([]byte("bar")))
})
It("reads DATA frames into too large buffers", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 10)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b[:n]).To(Equal([]byte("foobar")))
})
It("reads DATA frames into too large buffers, in multiple runs", func() {
buf.Write(getDataFrame([]byte("foobar")))
b := make([]byte, 4)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(b).To(Equal([]byte("foob")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
Expect(b[:n]).To(Equal([]byte("ar")))
})
It("reads multiple DATA frames", func() {
buf.Write(getDataFrame([]byte("foo")))
buf.Write(getDataFrame([]byte("bar")))
b := make([]byte, 6)
n, err := rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b[:n]).To(Equal([]byte("foo")))
n, err = rb.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(b[:n]).To(Equal([]byte("bar")))
})
It("skips HEADERS frames", func() {
buf.Write(getDataFrame([]byte("foo")))
(&headersFrame{Length: 10}).Write(buf)
buf.Write(make([]byte, 10))
buf.Write(getDataFrame([]byte("bar")))
b := make([]byte, 6)
n, err := io.ReadFull(rb, b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(6))
Expect(b).To(Equal([]byte("foobar")))
})
It("errors when it can't parse the frame", func() {
buf.Write([]byte("invalid"))
_, err := rb.Read([]byte{0})
Expect(err).To(HaveOccurred())
})
It("errors on unexpected frames", func() {
(&settingsFrame{}).Write(buf)
_, err := rb.Read([]byte{0})
Expect(err).To(MatchError("unexpected frame"))
})
}
It("closes requests", func() {
cb := &closingBuffer{Buffer: buf}
rb := newRequestBody(cb)
Expect(rb.Close()).To(Succeed())
Expect(cb.closed).To(BeFalse())
})
It("closes responses", func() {
cb := &closingBuffer{Buffer: buf}
rb := newResponseBody(cb)
Expect(rb.Close()).To(Succeed())
Expect(cb.closed).To(BeTrue())
})
})

182
http3/client.go Normal file
View file

@ -0,0 +1,182 @@
package http3
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"sync"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
)
const defaultUserAgent = "quic-go HTTP/3"
var defaultQuicConfig = &quic.Config{KeepAlive: true}
var dialAddr = quic.DialAddr
type roundTripperOpts struct {
DisableCompression bool
}
// client is a HTTP3 client doing requests
type client struct {
tlsConf *tls.Config
config *quic.Config
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
handshakeErr error
requestWriter *requestWriter
decoder *qpack.Decoder
hostname string
session quic.Session
logger utils.Logger
}
func newClient(
hostname string,
tlsConf *tls.Config,
_ *roundTripperOpts, // TODO: implement gzip compression
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
if tlsConf == nil {
tlsConf = &tls.Config{}
}
tlsConf.NextProtos = []string{"h3-19"}
if quicConfig == nil {
quicConfig = defaultQuicConfig
}
quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
logger := utils.DefaultLogger.WithPrefix("h3 client")
return &client{
hostname: authorityAddr("https", hostname),
tlsConf: tlsConf,
requestWriter: newRequestWriter(logger),
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
config: quicConfig,
dialer: dialer,
logger: logger,
}
}
func (c *client) dial() error {
var err error
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
go func() {
if err := c.setupSession(); err != nil {
c.session.CloseWithError(quic.ErrorCode(errorInternalError), err)
}
}()
// TODO: send a SETTINGS frame
return nil
}
func (c *client) setupSession() error {
// open the control stream
str, err := c.session.OpenUniStreamSync()
if err != nil {
return err
}
buf := &bytes.Buffer{}
// write the type byte
buf.Write([]byte{0x0})
// send the SETTINGS frame
(&settingsFrame{}).Write(buf)
if _, err := str.Write(buf.Bytes()); err != nil {
return err
}
return nil
}
func (c *client) Close() error {
return c.session.Close()
}
// Roundtrip executes a request and returns a response
// TODO: handle request cancelations
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Scheme != "https" {
return nil, errors.New("http3: unsupported scheme")
}
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
}
c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
}
str, err := c.session.OpenStreamSync()
if err != nil {
return nil, err
}
if err := c.requestWriter.WriteRequest(str, req); err != nil {
return nil, err
}
frame, err := parseNextFrame(str)
if err != nil {
return nil, err
}
hf, ok := frame.(*headersFrame)
if !ok {
return nil, errors.New("not a HEADERS frame")
}
// TODO: check size
headerBlock := make([]byte, hf.Length)
if _, err := io.ReadFull(str, headerBlock); err != nil {
return nil, err
}
hfs, err := c.decoder.DecodeFull(headerBlock)
if err != nil {
return nil, err
}
res := &http.Response{
Proto: "HTTP/3",
ProtoMajor: 3,
Header: http.Header{},
Body: newResponseBody(&responseBody{str}),
}
for _, hf := range hfs {
switch hf.Name {
case ":status":
status, err := strconv.Atoi(hf.Value)
if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header")
}
res.StatusCode = status
res.Status = hf.Value + " " + http.StatusText(status)
default:
res.Header.Add(hf.Name, hf.Value)
}
}
return res, nil
}

272
http3/client_test.go Normal file
View file

@ -0,0 +1,272 @@
package http3
import (
"bytes"
"crypto/tls"
"errors"
"io"
"net/http"
"time"
"github.com/golang/mock/gomock"
quic "github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Client", func() {
var (
client *client
req *http.Request
origDialAddr = dialAddr
)
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal(hostname))
var err error
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("uses the default QUIC config if none is give", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool
dialAddr = func(_ string, _ *tls.Config, quicConf *quic.Config) (quic.Session, error) {
Expect(quicConf).To(Equal(defaultQuicConfig))
dialAddrCalled = true
return nil, errors.New("test done")
}
client.RoundTrip(req)
Expect(dialAddrCalled).To(BeTrue())
})
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
Expect(hostname).To(Equal("quic.clemente.io:443"))
dialAddrCalled = true
return nil, errors.New("test done")
}
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
Expect(err).ToNot(HaveOccurred())
client.RoundTrip(req)
Expect(dialAddrCalled).To(BeTrue())
})
It("uses the TLS config and QUIC config", func() {
tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{IdleTimeout: time.Nanosecond}
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
var dialAddrCalled bool
dialAddr = func(
hostname string,
tlsConfP *tls.Config,
quicConfP *quic.Config,
) (quic.Session, error) {
Expect(hostname).To(Equal("localhost:1337"))
Expect(tlsConfP).To(Equal(tlsConf))
Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout))
dialAddrCalled = true
return nil, errors.New("test done")
}
client.RoundTrip(req)
Expect(dialAddrCalled).To(BeTrue())
})
It("uses the custom dialer, if provided", func() {
testErr := errors.New("test done")
tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{IdleTimeout: 1337 * time.Second}
var dialerCalled bool
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.Session, error) {
Expect(network).To(Equal("udp"))
Expect(address).To(Equal("localhost:1337"))
Expect(tlsConfP).To(Equal(tlsConf))
Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout))
dialerCalled = true
return nil, testErr
}
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
Expect(dialerCalled).To(BeTrue())
})
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("errors if it can't open a stream", func() {
testErr := errors.New("stream open error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session := mockquic.NewMockSession(mockCtrl)
session.EXPECT().OpenUniStreamSync().Return(nil, testErr).MaxTimes(1)
session.EXPECT().OpenStreamSync().Return(nil, testErr).MaxTimes(1)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
Context("Doing requests", func() {
var (
request *http.Request
str *mockquic.MockStream
sess *mockquic.MockSession
)
decodeHeader := func(str io.Reader) map[string]string {
fields := make(map[string]string)
decoder := qpack.NewDecoder(nil)
frame, err := parseNextFrame(str)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(str, data)
Expect(err).ToNot(HaveOccurred())
hfs, err := decoder.DecodeFull(data)
Expect(err).ToNot(HaveOccurred())
for _, p := range hfs {
fields[p.Name] = p.Value
}
return fields
}
BeforeEach(func() {
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write([]byte{0x0}).Return(1, nil).MaxTimes(1)
controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame
str = mockquic.NewMockStream(mockCtrl)
sess = mockquic.NewMockSession(mockCtrl)
sess.EXPECT().OpenUniStreamSync().Return(controlStr, nil).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return sess, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
It("sends a request", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).To(HaveKeyWithValue(":scheme", "https"))
Expect(hfs).To(HaveKeyWithValue(":method", "GET"))
Expect(hfs).To(HaveKeyWithValue(":authority", "quic.clemente.io:1337"))
Expect(hfs).To(HaveKeyWithValue(":path", "/file1.dat"))
})
It("returns a response", func() {
rspBuf := &bytes.Buffer{}
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418)
sess.EXPECT().OpenStreamSync().Return(str, nil)
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return rspBuf.Read(p)
}).AnyTimes()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3"))
Expect(rsp.ProtoMajor).To(Equal(3))
Expect(rsp.StatusCode).To(Equal(418))
})
Context("validating the address", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
It("refuses to do plain HTTP requests", func() {
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("http3: unsupported scheme"))
})
})
Context("requests containing a Body", func() {
var strBuf *bytes.Buffer
BeforeEach(func() {
strBuf = &bytes.Buffer{}
sess.EXPECT().OpenStreamSync().Return(str, nil)
body := &mockBody{}
body.SetData([]byte("request body"))
var err error
request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
Expect(err).ToNot(HaveOccurred())
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return strBuf.Write(p)
}).AnyTimes()
})
It("sends a request", func() {
done := make(chan struct{})
str.EXPECT().Close().Do(func() { close(done) })
// the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
<-done
return 0, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(strBuf)
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
})
It("returns the error that occurred when reading the body", func() {
request.Body.(*mockBody).readErr = errors.New("testErr")
done := make(chan struct{})
str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) {
close(done)
})
// the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
<-done
return 0, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
})
})
})
})

135
http3/frames.go Normal file
View file

@ -0,0 +1,135 @@
package http3
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type byteReader interface {
io.ByteReader
io.Reader
}
type byteReaderImpl struct{ io.Reader }
func (br *byteReaderImpl) ReadByte() (byte, error) {
b := make([]byte, 1)
if _, err := br.Reader.Read(b); err != nil {
return 0, err
}
return b[0], nil
}
type frame interface{}
func parseNextFrame(b io.Reader) (frame, error) {
br, ok := b.(byteReader)
if !ok {
br = &byteReaderImpl{b}
}
t, err := utils.ReadVarInt(br)
if err != nil {
return nil, err
}
l, err := utils.ReadVarInt(br)
if err != nil {
return nil, err
}
switch t {
case 0x0:
return &dataFrame{Length: l}, nil
case 0x1:
return &headersFrame{Length: l}, nil
case 0x4:
return parseSettingsFrame(br, l)
case 0x2: // PRIORITY
fallthrough
case 0x3: // CANCEL_PUSH
fallthrough
case 0x5: // PUSH_PROMISE
fallthrough
case 0x7: // GOAWAY
fallthrough
case 0xd: // MAX_PUSH_ID
fallthrough
case 0xe: // DUPLICATE_PUSH
fallthrough
default:
// skip over unknown frames
if _, err := io.CopyN(ioutil.Discard, br, int64(l)); err != nil {
return nil, err
}
return parseNextFrame(b)
}
}
type dataFrame struct {
Length uint64
}
func (f *dataFrame) Write(b *bytes.Buffer) {
utils.WriteVarInt(b, 0x0)
utils.WriteVarInt(b, f.Length)
}
type headersFrame struct {
Length uint64
}
func (f *headersFrame) Write(b *bytes.Buffer) {
utils.WriteVarInt(b, 0x1)
utils.WriteVarInt(b, f.Length)
}
type settingsFrame struct {
settings map[uint64]uint64
}
func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
if l > 8*(1<<10) {
return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l)
}
buf := make([]byte, l)
if _, err := io.ReadFull(r, buf); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
}
frame := &settingsFrame{settings: make(map[uint64]uint64)}
b := bytes.NewReader(buf)
for b.Len() > 0 {
id, err := utils.ReadVarInt(b)
if err != nil { // should not happen. We allocated the whole frame already.
return nil, err
}
val, err := utils.ReadVarInt(b)
if err != nil { // should not happen. We allocated the whole frame already.
return nil, err
}
if _, ok := frame.settings[id]; ok {
return nil, fmt.Errorf("duplicate setting: %d", id)
}
frame.settings[id] = val
}
return frame, nil
}
func (f *settingsFrame) Write(b *bytes.Buffer) {
utils.WriteVarInt(b, 0x4)
var l protocol.ByteCount
for id, val := range f.settings {
l += utils.VarIntLen(id) + utils.VarIntLen(val)
}
utils.WriteVarInt(b, uint64(l))
for id, val := range f.settings {
utils.WriteVarInt(b, id)
utils.WriteVarInt(b, val)
}
}

136
http3/frames_test.go Normal file
View file

@ -0,0 +1,136 @@
package http3
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Frames", func() {
appendVarInt := func(b []byte, val uint64) []byte {
buf := &bytes.Buffer{}
utils.WriteVarInt(buf, val)
return append(b, buf.Bytes()...)
}
It("skips unknown frames", func() {
data := appendVarInt(nil, 0xdeadbeef) // type byte
data = appendVarInt(data, 0x42)
data = append(data, make([]byte, 0x42)...)
buf := bytes.NewBuffer(data)
(&dataFrame{Length: 0x1234}).Write(buf)
frame, err := parseNextFrame(buf)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234)))
})
Context("DATA frames", func() {
It("parses", func() {
data := appendVarInt(nil, 0) // type byte
data = appendVarInt(data, 0x1337)
frame, err := parseNextFrame(bytes.NewReader(data))
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337)))
})
It("writes", func() {
buf := &bytes.Buffer{}
(&dataFrame{Length: 0xdeadbeef}).Write(buf)
frame, err := parseNextFrame(buf)
Expect(err).ToNot(HaveOccurred())
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(Equal(uint64(0xdeadbeef)))
})
})
Context("HEADERS frames", func() {
It("parses", func() {
data := appendVarInt(nil, 1) // type byte
data = appendVarInt(data, 0x1337)
frame, err := parseNextFrame(bytes.NewReader(data))
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337)))
})
It("writes", func() {
buf := &bytes.Buffer{}
(&headersFrame{Length: 0xdeadbeef}).Write(buf)
frame, err := parseNextFrame(buf)
Expect(err).ToNot(HaveOccurred())
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
Expect(frame.(*headersFrame).Length).To(Equal(uint64(0xdeadbeef)))
})
})
Context("SETTINGS frames", func() {
It("parses", func() {
settings := appendVarInt(nil, 13)
settings = appendVarInt(settings, 37)
settings = appendVarInt(settings, 0xdead)
settings = appendVarInt(settings, 0xbeef)
data := appendVarInt(nil, 4) // type byte
data = appendVarInt(data, uint64(len(settings)))
data = append(data, settings...)
frame, err := parseNextFrame(bytes.NewReader(data))
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{}))
sf := frame.(*settingsFrame)
Expect(sf.settings).To(HaveKeyWithValue(uint64(13), uint64(37)))
Expect(sf.settings).To(HaveKeyWithValue(uint64(0xdead), uint64(0xbeef)))
})
It("rejects duplicate settings", func() {
settings := appendVarInt(nil, 13)
settings = appendVarInt(settings, 37)
settings = appendVarInt(settings, 13)
settings = appendVarInt(settings, 38)
data := appendVarInt(nil, 4) // type byte
data = appendVarInt(data, uint64(len(settings)))
data = append(data, settings...)
_, err := parseNextFrame(bytes.NewReader(data))
Expect(err).To(MatchError("duplicate setting: 13"))
})
It("writes", func() {
sf := &settingsFrame{settings: map[uint64]uint64{
1: 2,
99: 999,
13: 37,
}}
buf := &bytes.Buffer{}
sf.Write(buf)
frame, err := parseNextFrame(buf)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(sf))
})
It("errors on EOF", func() {
sf := &settingsFrame{settings: map[uint64]uint64{
13: 37,
0xdeadbeef: 0xdecafbad,
}}
buf := &bytes.Buffer{}
sf.Write(buf)
data := buf.Bytes()
_, err := parseNextFrame(bytes.NewReader(data))
Expect(err).ToNot(HaveOccurred())
for i := range data {
b := make([]byte, i)
copy(b, data[:i])
_, err := parseNextFrame(bytes.NewReader(b))
Expect(err).To(MatchError(io.EOF))
}
})
})
})

View file

@ -3,6 +3,8 @@ package http3
import ( import (
"testing" "testing"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -11,3 +13,13 @@ func TestHttp3(t *testing.T) {
RegisterFailHandler(Fail) RegisterFailHandler(Fail)
RunSpecs(t, "HTTP/3 Suite") RunSpecs(t, "HTTP/3 Suite")
} }
var mockCtrl *gomock.Controller
var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())
})
var _ = AfterEach(func() {
mockCtrl.Finish()
})

View file

@ -1,4 +1,4 @@
package h2quic package http3
import ( import (
"crypto/tls" "crypto/tls"
@ -8,10 +8,10 @@ import (
"strconv" "strconv"
"strings" "strings"
"golang.org/x/net/http2/hpack" "github.com/marten-seemann/qpack"
) )
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) {
var path, authority, method, contentLengthStr string var path, authority, method, contentLengthStr string
httpHeaders := http.Header{} httpHeaders := http.Header{}
@ -57,8 +57,8 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
return &http.Request{ return &http.Request{
Method: method, Method: method,
URL: u, URL: u,
Proto: "HTTP/2.0", Proto: "HTTP/3",
ProtoMajor: 2, ProtoMajor: 3,
ProtoMinor: 0, ProtoMinor: 0,
Header: httpHeaders, Header: httpHeaders,
Body: nil, Body: nil,

312
http3/request_writer.go Normal file
View file

@ -0,0 +1,312 @@
package http3
import (
"bytes"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
)
type requestWriter struct {
mutex sync.Mutex
encoder *qpack.Encoder
headerBuf *bytes.Buffer
logger utils.Logger
}
func newRequestWriter(logger utils.Logger) *requestWriter {
headerBuf := &bytes.Buffer{}
encoder := qpack.NewEncoder(headerBuf)
return &requestWriter{
encoder: encoder,
headerBuf: headerBuf,
logger: logger,
}
}
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request) error {
headers, err := w.getHeaders(req)
if err != nil {
return err
}
if _, err := str.Write(headers); err != nil {
return err
}
// TODO: add support for trailers
if req.Body == nil {
str.Close()
return nil
}
// send the request body asynchronously
go func() {
if err := w.sendRequestBody(req.Body, str); err != nil {
w.logger.Errorf("Error writing request: %s", err)
return
}
str.Close()
}()
return nil
}
func (w *requestWriter) getHeaders(req *http.Request) ([]byte, error) {
w.mutex.Lock()
defer w.mutex.Unlock()
defer w.encoder.Close()
if err := w.encodeHeaders(req, false, "", actualContentLength(req)); err != nil {
return nil, err
}
buf := &bytes.Buffer{}
hf := headersFrame{Length: uint64(w.headerBuf.Len())}
hf.Write(buf)
if _, err := io.Copy(buf, w.headerBuf); err != nil {
return nil, err
}
w.headerBuf.Reset()
return buf.Bytes(), nil
}
func (w *requestWriter) sendRequestBody(req io.ReadCloser, str quic.Stream) error {
b := make([]byte, 8*1024)
for {
n, err := req.Read(b)
if err == io.EOF {
break
}
if err != nil {
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
return err
}
buf := &bytes.Buffer{}
(&dataFrame{Length: uint64(n)}).Write(buf)
if _, err := str.Write(buf.Bytes()); err != nil {
return err
}
if _, err := str.Write(b[:n]); err != nil {
return err
}
}
req.Close()
return nil
}
// copied from net/transport.go
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error {
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return err
}
var path string
if req.Method != "CONNECT" {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return fmt.Errorf("invalid request :path %q", orig)
}
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
}
enumerateHeaders := func(f func(name, value string)) {
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
f(":authority", host)
f(":method", req.Method)
if req.Method != "CONNECT" {
f(":path", path)
f(":scheme", req.URL.Scheme)
}
if trailers != "" {
f("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
strings.EqualFold(k, "keep-alive") {
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
} else if strings.EqualFold(k, "user-agent") {
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
f(k, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
f("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
f("accept-encoding", "gzip")
}
if !didUA {
f("user-agent", defaultUserAgent)
}
}
// Do a first pass over the headers counting bytes to ensure
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
// separate pass before encoding the headers to prevent
// modifying the hpack state.
hlSize := uint64(0)
enumerateHeaders(func(name, value string) {
hf := hpack.HeaderField{Name: name, Value: value}
hlSize += uint64(hf.Size())
})
// TODO: check maximum header list size
// if hlSize > cc.peerMaxHeaderListSize {
// return errRequestHeaderListSize
// }
// trace := httptrace.ContextClientTrace(req.Context())
// traceHeaders := traceHasWroteHeaderField(trace)
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
name = strings.ToLower(name)
w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
// if traceHeaders {
// traceWroteHeaderField(trace, name, value)
// }
})
return nil
}
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}
// validPseudoPath reports whether v is a valid :path pseudo-header
// value. It must be either:
//
// *) a non-empty string starting with '/'
// *) the string '*', for OPTIONS requests.
//
// For now this is only used a quick check for deciding when to clean
// up Opaque URLs before sending requests from the Transport.
// See golang.org/issue/16847
//
// We used to enforce that the path also didn't start with "//", but
// Google's GFE accepts such paths and Chrome sends them, so ignore
// that part of the spec. See golang.org/issue/19103.
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/') || v == "*"
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
func actualContentLength(req *http.Request) int64 {
if req.Body == nil {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}

View file

@ -0,0 +1,105 @@
package http3
import (
"bytes"
"io"
"net/http"
"strconv"
"github.com/marten-seemann/qpack"
"github.com/golang/mock/gomock"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request Writer", func() {
var (
rw *requestWriter
str *mockquic.MockStream
strBuf *bytes.Buffer
)
decode := func(str io.Reader) map[string]string {
frame, err := parseNextFrame(str)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(str, data)
Expect(err).ToNot(HaveOccurred())
decoder := qpack.NewDecoder(nil)
hfs, err := decoder.DecodeFull(data)
Expect(err).ToNot(HaveOccurred())
values := make(map[string]string)
for _, hf := range hfs {
values[hf.Name] = hf.Value
}
return values
}
BeforeEach(func() {
rw = newRequestWriter(utils.DefaultLogger)
strBuf = &bytes.Buffer{}
str = mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return strBuf.Write(p)
}).AnyTimes()
})
It("writes a GET request", func() {
str.EXPECT().Close()
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
})
It("writes a POST request", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
postData := bytes.NewReader([]byte("foobar"))
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
Expect(headerFields).To(HaveKey("content-length"))
contentLength, err := strconv.Atoi(headerFields["content-length"])
Expect(err).ToNot(HaveOccurred())
Expect(contentLength).To(BeNumerically(">", 0))
Eventually(closed).Should(BeClosed())
frame, err := parseNextFrame(strBuf)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6))
})
It("sends cookies", func() {
str.EXPECT().Close()
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
cookie1 := &http.Cookie{
Name: "Cookie #1",
Value: "Value #1",
}
cookie2 := &http.Cookie{
Name: "Cookie #2",
Value: "Value #2",
}
req.AddCookie(cookie1)
req.AddCookie(cookie2)
Expect(rw.WriteRequest(str, req)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
})
})

View file

@ -1,4 +1,4 @@
package h2quic package http3
import ( import (
"io" "io"

View file

@ -1,7 +1,8 @@
package h2quic package http3
import ( import (
"bytes" "github.com/golang/mock/gomock"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -9,21 +10,17 @@ import (
var _ = Describe("Response Body", func() { var _ = Describe("Response Body", func() {
var ( var (
stream *mockStream stream *mockquic.MockStream
body *responseBody body *responseBody
) )
BeforeEach(func() { BeforeEach(func() {
stream = newMockStream(42) stream = mockquic.NewMockStream(mockCtrl)
body = &responseBody{stream} body = &responseBody{stream}
}) })
It("calls CancelRead when closing", func() { It("calls CancelRead when closing", func() {
stream.dataToRead = *bytes.NewBuffer([]byte("foobar")) stream.EXPECT().CancelRead(gomock.Any())
n, err := body.Read(make([]byte, 3))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(body.Close()).To(Succeed()) Expect(body.Close()).To(Succeed())
Expect(stream.canceledRead).To(BeTrue())
}) })
}) })

View file

@ -1,25 +1,18 @@
package h2quic package http3
import ( import (
"bytes" "bytes"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"golang.org/x/net/http2" "github.com/marten-seemann/qpack"
"golang.org/x/net/http2/hpack"
) )
type responseWriter struct { type responseWriter struct {
dataStreamID protocol.StreamID stream io.Writer
dataStream quic.Stream
headerStream quic.Stream
headerStreamMutex *sync.Mutex
header http.Header header http.Header
status int // status code passed to WriteHeader status int // status code passed to WriteHeader
@ -28,20 +21,13 @@ type responseWriter struct {
logger utils.Logger logger utils.Logger
} }
func newResponseWriter( var _ http.ResponseWriter = &responseWriter{}
headerStream quic.Stream,
headerStreamMutex *sync.Mutex, func newResponseWriter(stream io.Writer, logger utils.Logger) *responseWriter {
dataStream quic.Stream,
dataStreamID protocol.StreamID,
logger utils.Logger,
) *responseWriter {
return &responseWriter{ return &responseWriter{
header: http.Header{}, header: http.Header{},
headerStream: headerStream, stream: stream,
headerStreamMutex: headerStreamMutex, logger: logger,
dataStream: dataStream,
dataStreamID: dataStreamID,
logger: logger,
} }
} }
@ -57,26 +43,23 @@ func (w *responseWriter) WriteHeader(status int) {
w.status = status w.status = status
var headers bytes.Buffer var headers bytes.Buffer
enc := hpack.NewEncoder(&headers) enc := qpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
for k, v := range w.header { for k, v := range w.header {
for index := range v { for index := range v {
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
} }
} }
buf := &bytes.Buffer{}
(&headersFrame{Length: uint64(headers.Len())}).Write(buf)
w.logger.Infof("Responding with %d", status) w.logger.Infof("Responding with %d", status)
w.headerStreamMutex.Lock() if _, err := w.stream.Write(buf.Bytes()); err != nil {
defer w.headerStreamMutex.Unlock() w.logger.Errorf("could not write headers frame: %s", err.Error())
h2framer := http2.NewFramer(w.headerStream, nil) }
err := h2framer.WriteHeaders(http2.HeadersFrameParam{ if _, err := w.stream.Write(headers.Bytes()); err != nil {
StreamID: uint32(w.dataStreamID), w.logger.Errorf("could not write header frame payload: %s", err.Error())
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
if err != nil {
w.logger.Errorf("could not write h2 header: %s", err.Error())
} }
} }
@ -87,7 +70,13 @@ func (w *responseWriter) Write(p []byte) (int, error) {
if !bodyAllowedForStatus(w.status) { if !bodyAllowedForStatus(w.status) {
return 0, http.ErrBodyNotAllowed return 0, http.ErrBodyNotAllowed
} }
return w.dataStream.Write(p) df := &dataFrame{Length: uint64(len(p))}
buf := &bytes.Buffer{}
df.Write(buf)
if _, err := w.stream.Write(buf.Bytes()); err != nil {
return 0, err
}
return w.stream.Write(p)
} }
func (w *responseWriter) Flush() {} func (w *responseWriter) Flush() {}

View file

@ -1,4 +1,4 @@
package h2quic package http3
import "net/http" import "net/http"

View file

@ -0,0 +1,120 @@
package http3
import (
"bytes"
"io"
"net/http"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Response Writer", func() {
var (
rw *responseWriter
strBuf *bytes.Buffer
)
BeforeEach(func() {
strBuf = &bytes.Buffer{}
rw = newResponseWriter(strBuf, utils.DefaultLogger)
})
decodeHeader := func(str io.Reader) map[string][]string {
fields := make(map[string][]string)
decoder := qpack.NewDecoder(nil)
frame, err := parseNextFrame(str)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(str, data)
Expect(err).ToNot(HaveOccurred())
hfs, err := decoder.DecodeFull(data)
Expect(err).ToNot(HaveOccurred())
for _, p := range hfs {
fields[p.Name] = append(fields[p.Name], p.Value)
}
return fields
}
getData := func(str io.Reader) []byte {
frame, err := parseNextFrame(str)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
df := frame.(*dataFrame)
data := make([]byte, df.Length)
_, err = io.ReadFull(str, data)
Expect(err).ToNot(HaveOccurred())
return data
}
It("writes status", func() {
rw.WriteHeader(http.StatusTeapot)
fields := decodeHeader(strBuf)
Expect(fields).To(HaveLen(1))
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
})
It("writes headers", func() {
rw.Header().Add("content-length", "42")
rw.WriteHeader(http.StatusTeapot)
fields := decodeHeader(strBuf)
Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
})
It("writes multiple headers with the same name", func() {
const cookie1 = "test1=1; Max-Age=7200; path=/"
const cookie2 = "test2=2; Max-Age=7200; path=/"
rw.Header().Add("set-cookie", cookie1)
rw.Header().Add("set-cookie", cookie2)
rw.WriteHeader(http.StatusTeapot)
fields := decodeHeader(strBuf)
Expect(fields).To(HaveKey("set-cookie"))
cookies := fields["set-cookie"]
Expect(cookies).To(ContainElement(cookie1))
Expect(cookies).To(ContainElement(cookie2))
})
It("writes data", func() {
n, err := rw.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
// Should have written 200 on the header stream
fields := decodeHeader(strBuf)
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
// And foobar on the data stream
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
})
It("writes data after WriteHeader is called", func() {
rw.WriteHeader(http.StatusTeapot)
n, err := rw.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
// Should have written 418 on the header stream
fields := decodeHeader(strBuf)
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
// And foobar on the data stream
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
})
It("does not WriteHeader() twice", func() {
rw.WriteHeader(200)
rw.WriteHeader(500)
fields := decodeHeader(strBuf)
Expect(fields).To(HaveLen(1))
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
})
It("doesn't allow writes if the status code doesn't allow a body", func() {
rw.WriteHeader(304)
n, err := rw.Write([]byte("foobar"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(http.ErrBodyNotAllowed))
})
})

View file

@ -1,4 +1,4 @@
package h2quic package http3
import ( import (
"crypto/tls" "crypto/tls"
@ -61,42 +61,42 @@ type RoundTripOpt struct {
var _ roundTripCloser = &RoundTripper{} var _ roundTripCloser = &RoundTripper{}
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available") var ErrNoCachedConn = errors.New("http3: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options. // RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil { if req.URL == nil {
closeRequestBody(req) closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL") return nil, errors.New("http3: nil Request.URL")
} }
if req.URL.Host == "" { if req.URL.Host == "" {
closeRequestBody(req) closeRequestBody(req)
return nil, errors.New("quic: no Host in request URL") return nil, errors.New("http3: no Host in request URL")
} }
if req.Header == nil { if req.Header == nil {
closeRequestBody(req) closeRequestBody(req)
return nil, errors.New("quic: nil Request.Header") return nil, errors.New("http3: nil Request.Header")
} }
if req.URL.Scheme == "https" { if req.URL.Scheme == "https" {
for k, vv := range req.Header { for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) { if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("quic: invalid http header field name %q", k) return nil, fmt.Errorf("http3: invalid http header field name %q", k)
} }
for _, v := range vv { for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) { if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
} }
} }
} }
} else { } else {
closeRequestBody(req) closeRequestBody(req)
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme) return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
} }
if req.Method != "" && !validMethod(req.Method) { if req.Method != "" && !validMethod(req.Method) {
closeRequestBody(req) closeRequestBody(req)
return nil, fmt.Errorf("quic: invalid method %q", req.Method) return nil, fmt.Errorf("http3: invalid method %q", req.Method)
} }
hostname := authorityAddr("https", hostnameFromRequest(req)) hostname := authorityAddr("https", hostnameFromRequest(req))

View file

@ -1,4 +1,4 @@
package h2quic package http3
import ( import (
"bytes" "bytes"
@ -8,7 +8,9 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/golang/mock/gomock"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -34,6 +36,9 @@ type mockBody struct {
closed bool closed bool
} }
// make sure the mockBody can be used as a http.Request.Body
var _ io.ReadCloser = &mockBody{}
func (m *mockBody) Read(p []byte) (int, error) { func (m *mockBody) Read(p []byte) (int, error) {
if m.readErr != nil { if m.readErr != nil {
return 0, m.readErr return 0, m.readErr
@ -50,13 +55,11 @@ func (m *mockBody) Close() error {
return m.closeErr return m.closeErr
} }
// make sure the mockBody can be used as a http.Request.Body
var _ io.ReadCloser = &mockBody{}
var _ = Describe("RoundTripper", func() { var _ = Describe("RoundTripper", func() {
var ( var (
rt *RoundTripper rt *RoundTripper
req1 *http.Request req1 *http.Request
session *mockquic.MockSession
) )
BeforeEach(func() { BeforeEach(func() {
@ -68,14 +71,14 @@ var _ = Describe("RoundTripper", func() {
Context("dialing hosts", func() { Context("dialing hosts", func() {
origDialAddr := dialAddr origDialAddr := dialAddr
streamOpenErr := errors.New("error opening stream")
BeforeEach(func() { BeforeEach(func() {
session = mockquic.NewMockSession(mockCtrl)
origDialAddr = dialAddr origDialAddr = dialAddr
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
// return an error when trying to open a stream // return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all // we don't want to test all the dial logic here, just that dialing happens at all
return &mockSession{streamOpenErr: streamOpenErr}, nil return session, nil
} }
}) })
@ -84,11 +87,17 @@ var _ = Describe("RoundTripper", func() {
}) })
It("creates new clients", func() { It("creates new clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err")
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync().Return(nil, testErr)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, error) { close(closed) })
_, err = rt.RoundTrip(req) _, err = rt.RoundTrip(req)
Expect(err).To(MatchError(streamOpenErr)) Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1)) Expect(rt.clients).To(HaveLen(1))
Eventually(closed).Should(BeClosed())
}) })
It("uses the quic.Config, if provided", func() { It("uses the quic.Config, if provided", func() {
@ -96,35 +105,43 @@ var _ = Describe("RoundTripper", func() {
var receivedConfig *quic.Config var receivedConfig *quic.Config
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
receivedConfig = config receivedConfig = config
return nil, errors.New("err") return nil, errors.New("handshake error")
} }
rt.QuicConfig = config rt.QuicConfig = config
rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(receivedConfig).To(Equal(config)) Expect(err).To(MatchError("handshake error"))
Expect(receivedConfig.HandshakeTimeout).To(Equal(config.HandshakeTimeout))
}) })
It("uses the custom dialer, if provided", func() { It("uses the custom dialer, if provided", func() {
var dialed bool var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) { dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
dialed = true dialed = true
return nil, errors.New("err") return nil, errors.New("handshake error")
} }
rt.Dial = dialer rt.Dial = dialer
rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("handshake error"))
Expect(dialed).To(BeTrue()) Expect(dialed).To(BeTrue())
}) })
It("reuses existing clients", func() { It("reuses existing clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err")
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync().Return(nil, testErr).Times(2)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, error) { close(closed) })
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req) _, err = rt.RoundTrip(req)
Expect(err).To(MatchError(streamOpenErr)) Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1)) Expect(rt.clients).To(HaveLen(1))
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req2) _, err = rt.RoundTrip(req2)
Expect(err).To(MatchError(streamOpenErr)) Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1)) Expect(rt.clients).To(HaveLen(1))
Eventually(closed).Should(BeClosed())
}) })
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
@ -141,7 +158,7 @@ var _ = Describe("RoundTripper", func() {
req.Body = &mockBody{} req.Body = &mockBody{}
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req) _, err = rt.RoundTrip(req)
Expect(err).To(MatchError("quic: unsupported protocol scheme: http")) Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
Expect(req.Body.(*mockBody).closed).To(BeTrue()) Expect(req.Body.(*mockBody).closed).To(BeTrue())
}) })
@ -149,7 +166,7 @@ var _ = Describe("RoundTripper", func() {
req1.URL = nil req1.URL = nil
req1.Body = &mockBody{} req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.URL")) Expect(err).To(MatchError("http3: nil Request.URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req1.Body.(*mockBody).closed).To(BeTrue())
}) })
@ -157,7 +174,7 @@ var _ = Describe("RoundTripper", func() {
req1.URL.Host = "" req1.URL.Host = ""
req1.Body = &mockBody{} req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: no Host in request URL")) Expect(err).To(MatchError("http3: no Host in request URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req1.Body.(*mockBody).closed).To(BeTrue())
}) })
@ -165,34 +182,34 @@ var _ = Describe("RoundTripper", func() {
req1.URL = nil req1.URL = nil
Expect(req1.Body).To(BeNil()) Expect(req1.Body).To(BeNil())
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.URL")) Expect(err).To(MatchError("http3: nil Request.URL"))
}) })
It("rejects requests without a header", func() { It("rejects requests without a header", func() {
req1.Header = nil req1.Header = nil
req1.Body = &mockBody{} req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.Header")) Expect(err).To(MatchError("http3: nil Request.Header"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req1.Body.(*mockBody).closed).To(BeTrue())
}) })
It("rejects requests with invalid header name fields", func() { It("rejects requests with invalid header name fields", func() {
req1.Header.Add("foobär", "value") req1.Header.Add("foobär", "value")
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: invalid http header field name \"foobär\"")) Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
}) })
It("rejects requests with invalid header name values", func() { It("rejects requests with invalid header name values", func() {
req1.Header.Add("foo", string([]byte{0x7})) req1.Header.Add("foo", string([]byte{0x7}))
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value")) Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
}) })
It("rejects requests with an invalid request method", func() { It("rejects requests with an invalid request method", func() {
req1.Method = "foobär" req1.Method = "foobär"
req1.Body = &mockBody{} req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: invalid method \"foobär\"")) Expect(err).To(MatchError("http3: invalid method \"foobär\""))
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req1.Body.(*mockBody).closed).To(BeTrue())
}) })
}) })

View file

@ -1,9 +1,10 @@
package h2quic package http3
import ( import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
@ -12,23 +13,12 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"golang.org/x/net/http2" "github.com/marten-seemann/qpack"
"golang.org/x/net/http2/hpack"
) )
type streamCreator interface {
quic.Session
GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
}
type remoteCloser interface {
CloseRemote(protocol.ByteCount)
}
// allows mocking of quic.Listen and quic.ListenAddr // allows mocking of quic.Listen and quic.ListenAddr
var ( var (
quicListen = quic.Listen quicListen = quic.Listen
@ -43,9 +33,6 @@ type Server struct {
// If nil, it uses reasonable default values. // If nil, it uses reasonable default values.
QuicConfig *quic.Config QuicConfig *quic.Config
// Private flag for demo, do not use
CloseAfterFirstRequest bool
port uint32 // used atomically port uint32 // used atomically
listenerMutex sync.Mutex listenerMutex sync.Mutex
@ -54,18 +41,18 @@ type Server struct {
supportedVersionsAsString string supportedVersionsAsString string
logger utils.Logger // will be set by Server.serveImpl() logger utils.Logger
} }
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
func (s *Server) ListenAndServe() error { func (s *Server) ListenAndServe() error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of http3.Server without http.Server")
} }
return s.serveImpl(s.TLSConfig, nil) return s.serveImpl(s.TLSConfig, nil)
} }
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
var err error var err error
certs := make([]tls.Certificate, 1) certs := make([]tls.Certificate, 1)
@ -88,7 +75,7 @@ func (s *Server) Serve(conn net.PacketConn) error {
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of http3.Server without http.Server")
} }
s.logger = utils.DefaultLogger.WithPrefix("server") s.logger = utils.DefaultLogger.WithPrefix("server")
s.listenerMutex.Lock() s.listenerMutex.Lock()
@ -120,138 +107,104 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if err != nil { if err != nil {
return err return err
} }
go s.handleHeaderStream(sess.(streamCreator)) go s.handleConn(sess)
} }
} }
func (s *Server) handleHeaderStream(session streamCreator) { func (s *Server) handleConn(sess quic.Session) {
stream, err := session.AcceptStream() // TODO: accept control streams
if err != nil { decoder := qpack.NewDecoder(nil)
session.CloseWithError(quic.ErrorCode(qerr.InternalError), err)
return
}
hpackDecoder := hpack.NewDecoder(4096, nil)
h2framer := http2.NewFramer(nil, stream)
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for { for {
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil { str, err := sess.AcceptStream()
// QuicErrors must originate from stream.Read() returning an error. if err != nil {
// In this case, the session has already logged the error, so we don't s.logger.Debugf("Accepting stream failed: %s", err)
// need to log it again.
errorCode := qerr.InternalError
if qerr, ok := err.(*qerr.QuicError); ok {
errorCode = qerr.ErrorCode
s.logger.Errorf("error handling h2 request: %s", err.Error())
}
session.CloseWithError(quic.ErrorCode(errorCode), err)
return return
} }
// TODO: handle error
go func() {
if err := s.handleRequest(str, decoder); err != nil {
s.logger.Debugf("Handling request failed: %s", err)
str.CancelWrite(quic.ErrorCode(errorGeneralProtocolError))
return
}
str.Close()
}()
} }
} }
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error { // TODO: improve error handling.
h2frame, err := h2framer.ReadFrame() // Most (but not all) of the errors occurring here are connection-level erros.
func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder) error {
frame, err := parseNextFrame(str)
if err != nil { if err != nil {
return qerr.Error(qerr.InternalError, "cannot read frame") str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
}
var h2headersFrame *http2.HeadersFrame
switch f := h2frame.(type) {
case *http2.PriorityFrame:
// ignore PRIORITY frames
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
return nil
case *http2.HeadersFrame:
h2headersFrame = f
default:
return qerr.Error(qerr.ProtocolViolation, "expected a header frame")
}
if !h2headersFrame.HeadersEnded() {
return errors.New("http2 header continuation not implemented")
}
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
if err != nil {
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
return err return err
} }
hf, ok := frame.(*headersFrame)
req, err := requestFromHeaders(headers) if !ok {
str.CancelWrite(quic.ErrorCode(errorUnexpectedFrame))
return errors.New("expected first frame to be a headers frame")
}
// TODO: check length
headerBlock := make([]byte, hf.Length)
if _, err := io.ReadFull(str, headerBlock); err != nil {
str.CancelWrite(quic.ErrorCode(errorIncompleteRequest))
return err
}
hfs, err := decoder.DecodeFull(headerBlock)
if err != nil {
// TODO: use the right error code
str.CancelWrite(quic.ErrorCode(errorGeneralProtocolError))
return err
}
req, err := requestFromHeaders(hfs)
if err != nil { if err != nil {
return err return err
} }
req.Body = newRequestBody(str)
if s.logger.Debug() { if s.logger.Debug() {
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
} else { } else {
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
} }
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID)) req = req.WithContext(str.Context())
if err != nil { responseWriter := newResponseWriter(str, s.logger)
return err handler := s.Handler
} if handler == nil {
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request handler = http.DefaultServeMux
if dataStream == nil {
return nil
} }
// handleRequest should be as non-blocking as possible to minimize var panicked, readEOF bool
// head-of-line blocking. Potentially blocking code is run in a separate func() {
// goroutine, enabling handleRequest to return before the code is executed. defer func() {
go func() { if p := recover(); p != nil {
streamEnded := h2headersFrame.StreamEnded() // Copied from net/http/server.go
if streamEnded { const size = 64 << 10
dataStream.(remoteCloser).CloseRemote(0) buf := make([]byte, size)
streamEnded = true buf = buf[:runtime.Stack(buf, false)]
_, _ = dataStream.Read([]byte{0}) // read the eof s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
} panicked = true
req = req.WithContext(dataStream.Context())
reqBody := newRequestBody(dataStream)
req.Body = reqBody
req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
}
panicked := false
func() {
defer func() {
if p := recover(); p != nil {
// Copied from net/http/server.go
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
panicked = true
}
}()
handler.ServeHTTP(responseWriter, req)
}()
if panicked {
responseWriter.WriteHeader(500)
} else {
responseWriter.WriteHeader(200)
}
if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
// in gQUIC, the error code doesn't matter, so just use 0 here
responseWriter.dataStream.CancelRead(0)
} }
responseWriter.dataStream.Close() }()
} handler.ServeHTTP(responseWriter, req)
if s.CloseAfterFirstRequest { // read the eof
time.Sleep(100 * time.Millisecond) if _, err = str.Read([]byte{}); err == io.EOF {
session.Close() readEOF = true
} }
}() }()
if panicked {
responseWriter.WriteHeader(500)
} else {
responseWriter.WriteHeader(200)
}
if !readEOF {
str.CancelRead(quic.ErrorCode(errorEarlyResponse))
}
return nil return nil
} }
@ -310,7 +263,7 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
} }
// ListenAndServeQUIC listens on the UDP network address addr and calls the // ListenAndServeQUIC listens on the UDP network address addr and calls the
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is // handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
// used when handler is nil. // used when handler is nil.
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error { func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
server := &Server{ server := &Server{

408
http3/server_test.go Normal file
View file

@ -0,0 +1,408 @@
package http3
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Server", func() {
var (
s *Server
// session *mockquic.MockSession
origQuicListenAddr = quicListenAddr
)
BeforeEach(func() {
s = &Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
logger: utils.DefaultLogger,
}
origQuicListenAddr = quicListenAddr
})
AfterEach(func() {
quicListenAddr = origQuicListenAddr
})
Context("handling requests", func() {
var (
qpackDecoder *qpack.Decoder
str *mockquic.MockStream
exampleGetRequest *http.Request
examplePostRequest *http.Request
)
reqContext := context.Background()
decodeHeader := func(str io.Reader) map[string][]string {
fields := make(map[string][]string)
decoder := qpack.NewDecoder(nil)
frame, err := parseNextFrame(str)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(str, data)
Expect(err).ToNot(HaveOccurred())
hfs, err := decoder.DecodeFull(data)
Expect(err).ToNot(HaveOccurred())
for _, p := range hfs {
fields[p.Name] = append(fields[p.Name], p.Value)
}
return fields
}
encodeRequest := func(req *http.Request) []byte {
buf := &bytes.Buffer{}
str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
}).AnyTimes()
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
rw := newRequestWriter(utils.DefaultLogger)
Expect(rw.WriteRequest(str, req)).To(Succeed())
if req.Body != nil {
b := make([]byte, 1000)
n, err := io.ReadFull(req.Body, b)
Expect(err).To(Equal(io.ErrUnexpectedEOF)) // otherwise b is too small for this test
(&dataFrame{Length: uint64(n)}).Write(buf)
buf.Write(b[:n])
}
Eventually(closed).Should(BeClosed())
return buf.Bytes()
}
setRequest := func(data []byte) {
buf := bytes.NewBuffer(data)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
if buf.Len() == 0 {
return 0, io.EOF
}
return buf.Read(p)
}).AnyTimes()
}
BeforeEach(func() {
var err error
exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil)
Expect(err).ToNot(HaveOccurred())
examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar")))
Expect(err).ToNot(HaveOccurred())
qpackDecoder = qpack.NewDecoder(nil)
str = mockquic.NewMockStream(mockCtrl)
})
It("calls the HTTP handler function", func() {
requestChan := make(chan *http.Request, 1)
s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
requestChan <- r
})
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
var req *http.Request
Eventually(requestChan).Should(Receive(&req))
Expect(req.Host).To(Equal("www.example.com"))
})
It("returns 200 with an empty handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return responseBuf.Write(p)
}).AnyTimes()
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
})
It("handles a panicking handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("foobar")
})
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return responseBuf.Write(p)
}).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
})
It("cancels reading when client sends a body in GET request", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(handlerCalled)
})
requestData := encodeRequest(exampleGetRequest)
buf := &bytes.Buffer{}
(&dataFrame{Length: 6}).Write(buf) // add a body
buf.Write([]byte("foobar"))
responseBuf := &bytes.Buffer{}
setRequest(append(requestData, buf.Bytes()...))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return responseBuf.Write(p)
}).AnyTimes()
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
})
It("cancels reading when the body of POST request is not read", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.Method).To(Equal("POST"))
close(handlerCalled)
})
setRequest(encodeRequest(examplePostRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
Eventually(handlerCalled).Should(BeClosed())
})
It("handles a request for which the client immediately resets the stream", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(handlerCalled)
})
testErr := errors.New("stream reset")
str.EXPECT().Read(gomock.Any()).Return(0, testErr)
str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled))
Expect(s.handleRequest(str, qpackDecoder)).To(MatchError(testErr))
Consistently(handlerCalled).ShouldNot(BeClosed())
})
It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = struct {
io.Reader
io.Closer
}{}
close(handlerCalled)
})
setRequest(encodeRequest(examplePostRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
Eventually(handlerCalled).Should(BeClosed())
})
It("cancels the request context when the stream is closed", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
Expect(r.Context().Done()).To(BeClosed())
Expect(r.Context().Err()).To(MatchError(context.Canceled))
close(handlerCalled)
})
setRequest(encodeRequest(examplePostRequest))
reqContext, cancel := context.WithCancel(context.Background())
cancel()
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil
}).AnyTimes()
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
Eventually(handlerCalled).Should(BeClosed())
})
})
Context("setting http headers", func() {
var expected http.Header
getExpectedHeader := func(versions []protocol.VersionNumber) http.Header {
var versionsAsString []string
for _, v := range versions {
versionsAsString = append(versionsAsString, v.ToAltSvc())
}
return http.Header{
"Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))},
}
}
BeforeEach(func() {
Expect(getExpectedHeader([]protocol.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}}))
expected = getExpectedHeader(protocol.SupportedVersions)
})
It("sets proper headers with numeric port", func() {
s.Server.Addr = ":443"
hdr := http.Header{}
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
Expect(hdr).To(Equal(expected))
})
It("sets proper headers with full addr", func() {
s.Server.Addr = "127.0.0.1:443"
hdr := http.Header{}
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
Expect(hdr).To(Equal(expected))
})
It("sets proper headers with string port", func() {
s.Server.Addr = ":https"
hdr := http.Header{}
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
Expect(hdr).To(Equal(expected))
})
It("works multiple times", func() {
s.Server.Addr = ":https"
hdr := http.Header{}
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
Expect(hdr).To(Equal(expected))
hdr = http.Header{}
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
Expect(hdr).To(Equal(expected))
})
})
It("errors when ListenAndServe is called with s.Server nil", func() {
Expect((&Server{}).ListenAndServe()).To(MatchError("use of http3.Server without http.Server"))
})
It("errors when ListenAndServeTLS is called with s.Server nil", func() {
Expect((&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError("use of http3.Server without http.Server"))
})
It("should nop-Close() when s.server is nil", func() {
Expect((&Server{}).Close()).To(Succeed())
})
It("errors when ListenAndServer is called after Close", func() {
serv := &Server{Server: &http.Server{}}
Expect(serv.Close()).To(Succeed())
Expect(serv.ListenAndServe()).To(MatchError("Server is already closed"))
})
Context("ListenAndServe", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"
})
AfterEach(func() {
Expect(s.Close()).To(Succeed())
})
It("may only be called once", func() {
cErr := make(chan error)
for i := 0; i < 2; i++ {
go func() {
defer GinkgoRecover()
if err := s.ListenAndServe(); err != nil {
cErr <- err
}
}()
}
Eventually(cErr).Should(Receive(MatchError("ListenAndServe may only be called once")))
Expect(s.Close()).To(Succeed())
})
It("uses the quic.Config to start the quic server", func() {
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
var receivedConf *quic.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
receivedConf = config
return nil, errors.New("listen err")
}
s.QuicConfig = conf
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf).To(Equal(conf))
})
})
Context("ListenAndServeTLS", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"
})
AfterEach(func() {
Expect(s.Close()).To(Succeed())
})
It("may only be called once", func() {
cErr := make(chan error)
for i := 0; i < 2; i++ {
go func() {
defer GinkgoRecover()
if err := s.ListenAndServeTLS(testdata.GetCertificatePaths()); err != nil {
cErr <- err
}
}()
}
Eventually(cErr).Should(Receive(MatchError("ListenAndServe may only be called once")))
Expect(s.Close()).To(Succeed())
})
})
It("closes gracefully", func() {
Expect(s.CloseGracefully(0)).To(Succeed())
})
It("errors when listening fails", func() {
testErr := errors.New("listen error")
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
return nil, testErr
}
fullpem, privkey := testdata.GetCertificatePaths()
Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr))
})
})

View file

@ -9,7 +9,7 @@ import (
"time" "time"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver" "github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
@ -38,7 +38,7 @@ var _ = Describe("HTTP tests", func() {
Context(fmt.Sprintf("with QUIC version %s", version), func() { Context(fmt.Sprintf("with QUIC version %s", version), func() {
BeforeEach(func() { BeforeEach(func() {
client = &http.Client{ client = &http.Client{
Transport: &h2quic.RoundTripper{ Transport: &http3.RoundTripper{
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
RootCAs: testdata.GetRootCA(), RootCAs: testdata.GetRootCA(),
}, },
@ -77,8 +77,20 @@ var _ = Describe("HTTP tests", func() {
Expect(body).To(Equal(testserver.PRDataLong)) Expect(body).To(Equal(testserver.PRDataLong))
}) })
// TODO(#1756): this test times out It("downloads many hellos", func() {
PIt("downloads many files, if the response is not read", func() { const num = 150
for i := 0; i < num; i++ {
resp, err := client.Get("https://localhost:" + testserver.Port() + "/hello")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World!\n"))
}
})
It("downloads many files, if the response is not read", func() {
const num = 150 const num = 150
for i := 0; i < num; i++ { for i := 0; i < num; i++ {
@ -89,6 +101,19 @@ var _ = Describe("HTTP tests", func() {
} }
}) })
It("posts a small message", func() {
resp, err := client.Post(
"https://localhost:"+testserver.Port()+"/echo",
"text/plain",
bytes.NewReader([]byte("Hello, world!")),
)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal([]byte("Hello, world!")))
})
It("uploads a file", func() { It("uploads a file", func() {
resp, err := client.Post( resp, err := client.Post(
"https://localhost:"+testserver.Port()+"/echo", "https://localhost:"+testserver.Port()+"/echo",
@ -99,7 +124,7 @@ var _ = Describe("HTTP tests", func() {
Expect(resp.StatusCode).To(Equal(200)) Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(bytes.Equal(body, testserver.PRData)).To(BeTrue()) Expect(body).To(Equal(testserver.PRData))
}) })
}) })
} }

View file

@ -8,7 +8,7 @@ import (
"strconv" "strconv"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
@ -27,7 +27,7 @@ var (
// PRDataLong contains dataLenLong bytes of pseudo-random data. // PRDataLong contains dataLenLong bytes of pseudo-random data.
PRDataLong = GeneratePRData(dataLenLong) PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server server *http3.Server
stoppedServing chan struct{} stoppedServing chan struct{}
port string port string
) )
@ -75,10 +75,10 @@ func GeneratePRData(l int) []byte {
return res return res
} }
// StartQuicServer starts a h2quic.Server. // StartQuicServer starts a http3.Server.
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used. // versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
func StartQuicServer(versions []protocol.VersionNumber) { func StartQuicServer(versions []protocol.VersionNumber) {
server = &h2quic.Server{ server = &http3.Server{
Server: &http.Server{ Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(), TLSConfig: testdata.GetTLSConfig(),
}, },
@ -102,7 +102,7 @@ func StartQuicServer(versions []protocol.VersionNumber) {
}() }()
} }
// StopQuicServer stops the h2quic.Server. // StopQuicServer stops the http3.Server.
func StopQuicServer() { func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred()) Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed()) Eventually(stoppedServing).Should(BeClosed())

View file

@ -1,5 +1,7 @@
package mocks package mocks
//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go"
//go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go"
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer" //go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener" //go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup" //go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"

View file

@ -0,0 +1,213 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: Session)
// Package mockquic is a generated GoMock package.
package mockquic
import (
context "context"
tls "crypto/tls"
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
quic_go "github.com/lucas-clemente/quic-go"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockSession is a mock of Session interface
type MockSession struct {
ctrl *gomock.Controller
recorder *MockSessionMockRecorder
}
// MockSessionMockRecorder is the mock recorder for MockSession
type MockSessionMockRecorder struct {
mock *MockSession
}
// NewMockSession creates a new mock instance
func NewMockSession(ctrl *gomock.Controller) *MockSession {
mock := &MockSession{ctrl: ctrl}
mock.recorder = &MockSessionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSession) EXPECT() *MockSessionMockRecorder {
return m.recorder
}
// AcceptStream mocks base method
func (m *MockSession) AcceptStream() (quic_go.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream")
ret0, _ := ret[0].(quic_go.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptStream indicates an expected call of AcceptStream
func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream))
}
// AcceptUniStream mocks base method
func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptUniStream")
ret0, _ := ret[0].(quic_go.ReceiveStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptUniStream indicates an expected call of AcceptUniStream
func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream))
}
// Close mocks base method
func (m *MockSession) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockSessionMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSession)(nil).Close))
}
// CloseWithError mocks base method
func (m *MockSession) CloseWithError(arg0 protocol.ApplicationErrorCode, arg1 error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// CloseWithError indicates an expected call of CloseWithError
func (mr *MockSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockSession)(nil).CloseWithError), arg0, arg1)
}
// ConnectionState mocks base method
func (m *MockSession) ConnectionState() tls.ConnectionState {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(tls.ConnectionState)
return ret0
}
// ConnectionState indicates an expected call of ConnectionState
func (mr *MockSessionMockRecorder) ConnectionState() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockSession)(nil).ConnectionState))
}
// Context mocks base method
func (m *MockSession) Context() context.Context {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context
func (mr *MockSessionMockRecorder) Context() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSession)(nil).Context))
}
// LocalAddr mocks base method
func (m *MockSession) LocalAddr() net.Addr {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LocalAddr")
ret0, _ := ret[0].(net.Addr)
return ret0
}
// LocalAddr indicates an expected call of LocalAddr
func (mr *MockSessionMockRecorder) LocalAddr() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockSession)(nil).LocalAddr))
}
// OpenStream mocks base method
func (m *MockSession) OpenStream() (quic_go.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStream")
ret0, _ := ret[0].(quic_go.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStream indicates an expected call of OpenStream
func (mr *MockSessionMockRecorder) OpenStream() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockSession)(nil).OpenStream))
}
// OpenStreamSync mocks base method
func (m *MockSession) OpenStreamSync() (quic_go.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync")
ret0, _ := ret[0].(quic_go.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync
func (mr *MockSessionMockRecorder) OpenStreamSync() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync))
}
// OpenUniStream mocks base method
func (m *MockSession) OpenUniStream() (quic_go.SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStream")
ret0, _ := ret[0].(quic_go.SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStream indicates an expected call of OpenUniStream
func (mr *MockSessionMockRecorder) OpenUniStream() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockSession)(nil).OpenUniStream))
}
// OpenUniStreamSync mocks base method
func (m *MockSession) OpenUniStreamSync() (quic_go.SendStream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenUniStreamSync")
ret0, _ := ret[0].(quic_go.SendStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
func (mr *MockSessionMockRecorder) OpenUniStreamSync() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync))
}
// RemoteAddr mocks base method
func (m *MockSession) RemoteAddr() net.Addr {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoteAddr")
ret0, _ := ret[0].(net.Addr)
return ret0
}
// RemoteAddr indicates an expected call of RemoteAddr
func (mr *MockSessionMockRecorder) RemoteAddr() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockSession)(nil).RemoteAddr))
}

View file

@ -0,0 +1,175 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go (interfaces: Stream)
// Package mockquic is a generated GoMock package.
package mockquic
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockStream is a mock of Stream interface
type MockStream struct {
ctrl *gomock.Controller
recorder *MockStreamMockRecorder
}
// MockStreamMockRecorder is the mock recorder for MockStream
type MockStreamMockRecorder struct {
mock *MockStream
}
// NewMockStream creates a new mock instance
func NewMockStream(ctrl *gomock.Controller) *MockStream {
mock := &MockStream{ctrl: ctrl}
mock.recorder = &MockStreamMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStream) EXPECT() *MockStreamMockRecorder {
return m.recorder
}
// CancelRead mocks base method
func (m *MockStream) CancelRead(arg0 protocol.ApplicationErrorCode) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "CancelRead", arg0)
}
// CancelRead indicates an expected call of CancelRead
func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStream)(nil).CancelRead), arg0)
}
// CancelWrite mocks base method
func (m *MockStream) CancelWrite(arg0 protocol.ApplicationErrorCode) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "CancelWrite", arg0)
}
// CancelWrite indicates an expected call of CancelWrite
func (mr *MockStreamMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStream)(nil).CancelWrite), arg0)
}
// Close mocks base method
func (m *MockStream) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockStreamMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStream)(nil).Close))
}
// Context mocks base method
func (m *MockStream) Context() context.Context {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context
func (mr *MockStreamMockRecorder) Context() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStream)(nil).Context))
}
// Read mocks base method
func (m *MockStream) Read(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockStreamMockRecorder) Read(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0)
}
// SetDeadline mocks base method
func (m *MockStream) SetDeadline(arg0 time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetDeadline indicates an expected call of SetDeadline
func (mr *MockStreamMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0)
}
// SetReadDeadline mocks base method
func (m *MockStream) SetReadDeadline(arg0 time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetReadDeadline indicates an expected call of SetReadDeadline
func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0)
}
// SetWriteDeadline mocks base method
func (m *MockStream) SetWriteDeadline(arg0 time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetWriteDeadline indicates an expected call of SetWriteDeadline
func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0)
}
// StreamID mocks base method
func (m *MockStream) StreamID() protocol.StreamID {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StreamID")
ret0, _ := ret[0].(protocol.StreamID)
return ret0
}
// StreamID indicates an expected call of StreamID
func (mr *MockStreamMockRecorder) StreamID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStream)(nil).StreamID))
}
// Write mocks base method
func (m *MockStream) Write(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockStreamMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0)
}