mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
implement HTTP/3
This commit is contained in:
parent
1325909ab7
commit
4f6d0e651a
43 changed files with 2511 additions and 2540 deletions
|
@ -1,6 +1,6 @@
|
||||||
run:
|
run:
|
||||||
skip-files:
|
skip-files:
|
||||||
- h2quic/response_writer_closenotifier.go
|
- http3/response_writer_closenotifier.go
|
||||||
- internal/handshake/unsafe_test.go
|
- internal/handshake/unsafe_test.go
|
||||||
|
|
||||||
linters-settings:
|
linters-settings:
|
||||||
|
|
12
README.md
12
README.md
|
@ -34,11 +34,7 @@ Running tests:
|
||||||
|
|
||||||
go test ./...
|
go test ./...
|
||||||
|
|
||||||
### HTTP mapping
|
### QUIC without HTTP/3
|
||||||
|
|
||||||
We're currently not implementing the HTTP mapping as described in the [QUIC over HTTP draft](https://quicwg.org/base-drafts/draft-ietf-quic-http.html). The HTTP mapping here is a leftover from Google QUIC.
|
|
||||||
|
|
||||||
### QUIC without HTTP/2
|
|
||||||
|
|
||||||
Take a look at [this echo example](example/echo/echo.go).
|
Take a look at [this echo example](example/echo/echo.go).
|
||||||
|
|
||||||
|
@ -50,16 +46,16 @@ See the [example server](example/main.go). Starting a QUIC server is very simila
|
||||||
|
|
||||||
```go
|
```go
|
||||||
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
||||||
h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
|
http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
|
||||||
```
|
```
|
||||||
|
|
||||||
### As a client
|
### As a client
|
||||||
|
|
||||||
See the [example client](example/client/main.go). Use a `h2quic.RoundTripper` as a `Transport` in a `http.Client`.
|
See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
http.Client{
|
http.Client{
|
||||||
Transport: &h2quic.RoundTripper{},
|
Transport: &http3.RoundTripper{},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,6 @@ coverage:
|
||||||
- streams_map_incoming_uni.go
|
- streams_map_incoming_uni.go
|
||||||
- streams_map_outgoing_bidi.go
|
- streams_map_outgoing_bidi.go
|
||||||
- streams_map_outgoing_uni.go
|
- streams_map_outgoing_uni.go
|
||||||
- h2quic/gzipreader.go
|
|
||||||
- h2quic/response.go
|
|
||||||
- internal/ackhandler/packet_linkedlist.go
|
- internal/ackhandler/packet_linkedlist.go
|
||||||
- internal/utils/byteinterval_linkedlist.go
|
- internal/utils/byteinterval_linkedlist.go
|
||||||
- internal/utils/packetinterval_linkedlist.go
|
- internal/utils/packetinterval_linkedlist.go
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/http3"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
@ -28,7 +28,7 @@ func main() {
|
||||||
}
|
}
|
||||||
logger.SetLogTimeFormat("")
|
logger.SetLogTimeFormat("")
|
||||||
|
|
||||||
roundTripper := &h2quic.RoundTripper{
|
roundTripper := &http3.RoundTripper{
|
||||||
TLSClientConfig: &tls.Config{
|
TLSClientConfig: &tls.Config{
|
||||||
RootCAs: testdata.GetRootCA(),
|
RootCAs: testdata.GetRootCA(),
|
||||||
},
|
},
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
|
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/http3"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
@ -135,9 +135,9 @@ func main() {
|
||||||
var err error
|
var err error
|
||||||
if *tcp {
|
if *tcp {
|
||||||
certFile, keyFile := testdata.GetCertificatePaths()
|
certFile, keyFile := testdata.GetCertificatePaths()
|
||||||
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
err = http3.ListenAndServe(bCap, certFile, keyFile, nil)
|
||||||
} else {
|
} else {
|
||||||
server := h2quic.Server{
|
server := http3.Server{
|
||||||
Server: &http.Server{Addr: bCap},
|
Server: &http.Server{Addr: bCap},
|
||||||
}
|
}
|
||||||
err = server.ListenAndServeTLS(testdata.GetCertificatePaths())
|
err = server.ListenAndServeTLS(testdata.GetCertificatePaths())
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -5,9 +5,10 @@ go 1.12
|
||||||
require (
|
require (
|
||||||
github.com/cheekybits/genny v1.0.0
|
github.com/cheekybits/genny v1.0.0
|
||||||
github.com/golang/mock v1.2.0
|
github.com/golang/mock v1.2.0
|
||||||
|
github.com/marten-seemann/qpack v0.1.0
|
||||||
github.com/marten-seemann/qtls v0.2.3
|
github.com/marten-seemann/qtls v0.2.3
|
||||||
github.com/onsi/ginkgo v1.7.0
|
github.com/onsi/ginkgo v1.7.0
|
||||||
github.com/onsi/gomega v1.4.3
|
github.com/onsi/gomega v1.4.3
|
||||||
golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25
|
golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25
|
||||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd
|
golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7
|
||||||
)
|
)
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -8,6 +8,8 @@ github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM
|
||||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
||||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
|
github.com/marten-seemann/qpack v0.1.0 h1:/0M7lkda/6mus9B8u34Asqm8ZhHAAt9Ho0vniNuVSVg=
|
||||||
|
github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI=
|
||||||
github.com/marten-seemann/qtls v0.2.3 h1:0yWJ43C62LsZt08vuQJDK1uC1czUc3FJeCLPoNAI4vA=
|
github.com/marten-seemann/qtls v0.2.3 h1:0yWJ43C62LsZt08vuQJDK1uC1czUc3FJeCLPoNAI4vA=
|
||||||
github.com/marten-seemann/qtls v0.2.3/go.mod h1:xzjG7avBwGGbdZ8dTGxlBnLArsVKLvwmjgmPuiQEcYk=
|
github.com/marten-seemann/qtls v0.2.3/go.mod h1:xzjG7avBwGGbdZ8dTGxlBnLArsVKLvwmjgmPuiQEcYk=
|
||||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||||
|
@ -19,6 +21,8 @@ golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 h1:jsG6UpNLt9iAsb0S2AGW28
|
||||||
golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA=
|
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA=
|
||||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7 h1:Qe/u+eY379X4He4GBMFZYu3pmh1ML5yT1aL1ndNM1zQ=
|
||||||
|
golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
|
||||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
|
311
h2quic/client.go
311
h2quic/client.go
|
@ -1,311 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
"golang.org/x/net/idna"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type roundTripperOpts struct {
|
|
||||||
DisableCompression bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var dialAddr = quic.DialAddr
|
|
||||||
|
|
||||||
// client is a HTTP2 client doing QUIC requests
|
|
||||||
type client struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
tlsConf *tls.Config
|
|
||||||
config *quic.Config
|
|
||||||
opts *roundTripperOpts
|
|
||||||
|
|
||||||
hostname string
|
|
||||||
handshakeErr error
|
|
||||||
dialOnce sync.Once
|
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
|
||||||
|
|
||||||
session quic.Session
|
|
||||||
headerStream quic.Stream
|
|
||||||
headerErr *qerr.QuicError
|
|
||||||
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
|
||||||
requestWriter *requestWriter
|
|
||||||
|
|
||||||
responses map[protocol.StreamID]chan *http.Response
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ http.RoundTripper = &client{}
|
|
||||||
|
|
||||||
var defaultQuicConfig = &quic.Config{KeepAlive: true}
|
|
||||||
|
|
||||||
// newClient creates a new client
|
|
||||||
func newClient(
|
|
||||||
hostname string,
|
|
||||||
tlsConfig *tls.Config,
|
|
||||||
opts *roundTripperOpts,
|
|
||||||
quicConfig *quic.Config,
|
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
|
||||||
) *client {
|
|
||||||
config := defaultQuicConfig
|
|
||||||
if quicConfig != nil {
|
|
||||||
config = quicConfig
|
|
||||||
}
|
|
||||||
return &client{
|
|
||||||
hostname: authorityAddr("https", hostname),
|
|
||||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
|
||||||
tlsConf: tlsConfig,
|
|
||||||
config: config,
|
|
||||||
opts: opts,
|
|
||||||
headerErrored: make(chan struct{}),
|
|
||||||
dialer: dialer,
|
|
||||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dial dials the connection
|
|
||||||
func (c *client) dial() error {
|
|
||||||
var err error
|
|
||||||
if c.dialer != nil {
|
|
||||||
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
|
||||||
} else {
|
|
||||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// once the version has been negotiated, open the header stream
|
|
||||||
c.headerStream, err = c.session.OpenStreamSync()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
|
|
||||||
go c.handleHeaderStream()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleHeaderStream() {
|
|
||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
||||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
for err == nil {
|
|
||||||
err = c.readResponse(h2framer, decoder)
|
|
||||||
}
|
|
||||||
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.NoError {
|
|
||||||
c.logger.Debugf("Error handling header stream: %s", err)
|
|
||||||
}
|
|
||||||
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
|
||||||
// stop all running request
|
|
||||||
close(c.headerErrored)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
|
||||||
frame, err := h2framer.ReadFrame()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hframe, ok := frame.(*http2.HeadersFrame)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("not a headers frame")
|
|
||||||
}
|
|
||||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
|
||||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mutex.RLock()
|
|
||||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
|
||||||
c.mutex.RUnlock()
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
rsp, err := responseFromHeaders(mhframe)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
responseChan <- rsp
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Roundtrip executes a request and returns a response
|
|
||||||
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
// TODO: add port to address, if it doesn't have one
|
|
||||||
if req.URL.Scheme != "https" {
|
|
||||||
return nil, errors.New("quic http2: unsupported scheme")
|
|
||||||
}
|
|
||||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
|
||||||
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.dialOnce.Do(func() {
|
|
||||||
c.handshakeErr = c.dial()
|
|
||||||
})
|
|
||||||
|
|
||||||
if c.handshakeErr != nil {
|
|
||||||
return nil, c.handshakeErr
|
|
||||||
}
|
|
||||||
|
|
||||||
hasBody := (req.Body != nil)
|
|
||||||
|
|
||||||
responseChan := make(chan *http.Response)
|
|
||||||
dataStream, err := c.session.OpenStreamSync()
|
|
||||||
if err != nil {
|
|
||||||
_ = c.closeWithError(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.mutex.Lock()
|
|
||||||
c.responses[dataStream.StreamID()] = responseChan
|
|
||||||
c.mutex.Unlock()
|
|
||||||
|
|
||||||
var requestedGzip bool
|
|
||||||
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
|
||||||
requestedGzip = true
|
|
||||||
}
|
|
||||||
// TODO: add support for trailers
|
|
||||||
endStream := !hasBody
|
|
||||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
|
||||||
if err != nil {
|
|
||||||
_ = c.closeWithError(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resc := make(chan error, 1)
|
|
||||||
if hasBody {
|
|
||||||
go func() {
|
|
||||||
resc <- c.writeRequestBody(dataStream, req.Body)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
var res *http.Response
|
|
||||||
|
|
||||||
var receivedResponse bool
|
|
||||||
var bodySent bool
|
|
||||||
|
|
||||||
if !hasBody {
|
|
||||||
bodySent = true
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := req.Context()
|
|
||||||
for !(bodySent && receivedResponse) {
|
|
||||||
select {
|
|
||||||
case res = <-responseChan:
|
|
||||||
receivedResponse = true
|
|
||||||
c.mutex.Lock()
|
|
||||||
delete(c.responses, dataStream.StreamID())
|
|
||||||
c.mutex.Unlock()
|
|
||||||
case err := <-resc:
|
|
||||||
bodySent = true
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
// error code 6 signals that stream was canceled
|
|
||||||
dataStream.CancelRead(6)
|
|
||||||
dataStream.CancelWrite(6)
|
|
||||||
c.mutex.Lock()
|
|
||||||
delete(c.responses, dataStream.StreamID())
|
|
||||||
c.mutex.Unlock()
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-c.headerErrored:
|
|
||||||
// an error occurred on the header stream
|
|
||||||
_ = c.closeWithError(c.headerErr)
|
|
||||||
return nil, c.headerErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: correctly set this variable
|
|
||||||
var streamEnded bool
|
|
||||||
isHead := (req.Method == "HEAD")
|
|
||||||
|
|
||||||
res = setLength(res, isHead, streamEnded)
|
|
||||||
|
|
||||||
if streamEnded || isHead {
|
|
||||||
res.Body = noBody
|
|
||||||
} else {
|
|
||||||
res.Body = &responseBody{dataStream}
|
|
||||||
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
|
||||||
res.Header.Del("Content-Encoding")
|
|
||||||
res.Header.Del("Content-Length")
|
|
||||||
res.ContentLength = -1
|
|
||||||
res.Body = &gzipReader{body: res.Body}
|
|
||||||
res.Uncompressed = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Request = req
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
|
|
||||||
defer func() {
|
|
||||||
cerr := body.Close()
|
|
||||||
if err == nil {
|
|
||||||
// TODO: what to do with dataStream here? Maybe reset it?
|
|
||||||
err = cerr
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err = io.Copy(dataStream, body)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: what to do with dataStream here? Maybe reset it?
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return dataStream.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) closeWithError(e error) error {
|
|
||||||
if c.session == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the client
|
|
||||||
func (c *client) Close() error {
|
|
||||||
if c.session == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.session.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// copied from net/transport.go
|
|
||||||
|
|
||||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
|
||||||
// and returns a host:port. The port 443 is added if needed.
|
|
||||||
func authorityAddr(scheme string, authority string) (addr string) {
|
|
||||||
host, port, err := net.SplitHostPort(authority)
|
|
||||||
if err != nil { // authority didn't have a port
|
|
||||||
port = "443"
|
|
||||||
if scheme == "http" {
|
|
||||||
port = "80"
|
|
||||||
}
|
|
||||||
host = authority
|
|
||||||
}
|
|
||||||
if a, err := idna.ToASCII(host); err == nil {
|
|
||||||
host = a
|
|
||||||
}
|
|
||||||
// IPv6 address literal, without a port:
|
|
||||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
|
||||||
return host + ":" + port
|
|
||||||
}
|
|
||||||
return net.JoinHostPort(host, port)
|
|
||||||
}
|
|
|
@ -1,640 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"compress/gzip"
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
"time"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Client", func() {
|
|
||||||
var (
|
|
||||||
client *client
|
|
||||||
session *mockSession
|
|
||||||
headerStream *mockStream
|
|
||||||
req *http.Request
|
|
||||||
origDialAddr = dialAddr
|
|
||||||
)
|
|
||||||
|
|
||||||
injectResponse := func(id protocol.StreamID, rsp *http.Response) {
|
|
||||||
EventuallyWithOffset(0, func() bool {
|
|
||||||
client.mutex.Lock()
|
|
||||||
defer client.mutex.Unlock()
|
|
||||||
_, ok := client.responses[id]
|
|
||||||
return ok
|
|
||||||
}).Should(BeTrue())
|
|
||||||
rspChan := client.responses[5]
|
|
||||||
ExpectWithOffset(0, rspChan).ToNot(BeClosed())
|
|
||||||
rspChan <- rsp
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
origDialAddr = dialAddr
|
|
||||||
hostname := "quic.clemente.io:1337"
|
|
||||||
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
Expect(client.hostname).To(Equal(hostname))
|
|
||||||
session = newMockSession()
|
|
||||||
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
|
||||||
client.session = session
|
|
||||||
|
|
||||||
headerStream = newMockStream(3)
|
|
||||||
client.headerStream = headerStream
|
|
||||||
client.requestWriter = newRequestWriter(headerStream, utils.DefaultLogger)
|
|
||||||
var err error
|
|
||||||
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
dialAddr = origDialAddr
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves the TLS config", func() {
|
|
||||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
|
||||||
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
|
|
||||||
Expect(client.tlsConf).To(Equal(tlsConf))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves the QUIC config", func() {
|
|
||||||
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
|
||||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
|
|
||||||
Expect(client.config).To(Equal(quicConf))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the default QUIC config if none is give", func() {
|
|
||||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
|
|
||||||
Expect(client.config).ToNot(BeNil())
|
|
||||||
Expect(client.config).To(Equal(defaultQuicConfig))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds the port to the hostname, if none is given", func() {
|
|
||||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("dials", func() {
|
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
close(headerStream.unblockRead)
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
// fmt.Println("done")
|
|
||||||
}()
|
|
||||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when dialing fails", func() {
|
|
||||||
testErr := errors.New("handshake error")
|
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return nil, testErr
|
|
||||||
}
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the custom dialer, if provided", func() {
|
|
||||||
var tlsCfg *tls.Config
|
|
||||||
var qCfg *quic.Config
|
|
||||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
|
||||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
|
||||||
tlsCfg = tlsCfgP
|
|
||||||
qCfg = cfg
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
|
||||||
Expect(qCfg).To(Equal(client.config))
|
|
||||||
Expect(tlsCfg).To(Equal(client.tlsConf))
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if it can't open a stream", func() {
|
|
||||||
testErr := errors.New("you shall not pass")
|
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
session.streamOpenErr = testErr
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns a request when dial fails", func() {
|
|
||||||
testErr := errors.New("dial error")
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return nil, testErr
|
|
||||||
}
|
|
||||||
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
_, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
_, err = client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Doing requests", func() {
|
|
||||||
var request *http.Request
|
|
||||||
var dataStream *mockStream
|
|
||||||
|
|
||||||
getRequest := func(data []byte) *http2.MetaHeadersFrame {
|
|
||||||
r := bytes.NewReader(data)
|
|
||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
||||||
h2framer := http2.NewFramer(nil, r)
|
|
||||||
frame, err := h2framer.ReadFrame()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
|
|
||||||
mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
return mhframe
|
|
||||||
}
|
|
||||||
|
|
||||||
getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
|
|
||||||
fields := make(map[string]string)
|
|
||||||
for _, hf := range f.Fields {
|
|
||||||
fields[hf.Name] = hf.Value
|
|
||||||
}
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
dataStream = newMockStream(5)
|
|
||||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
|
|
||||||
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does a request", func() {
|
|
||||||
teapot := &http.Response{
|
|
||||||
Status: "418 I'm a teapot",
|
|
||||||
StatusCode: 418,
|
|
||||||
}
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp).To(Equal(teapot))
|
|
||||||
Expect(rsp.Body).To(BeAssignableToTypeOf(&responseBody{}))
|
|
||||||
Expect(rsp.Body.(*responseBody).Stream).To(Equal(dataStream))
|
|
||||||
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
|
|
||||||
Expect(rsp.Request).To(Equal(request))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
|
||||||
injectResponse(5, teapot)
|
|
||||||
Expect(client.headerErrored).ToNot(BeClosed())
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if a request without a body is canceled", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
request = request.WithContext(ctx)
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(dataStream.canceledRead).To(BeTrue())
|
|
||||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
|
||||||
Expect(client.headerErrored).ToNot(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if a request with a body is canceled after the body is sent", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
request = request.WithContext(ctx)
|
|
||||||
request.Body = &mockBody{}
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
cancel()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(dataStream.canceledRead).To(BeTrue())
|
|
||||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
|
||||||
Expect(client.headerErrored).ToNot(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if a request with a body is canceled before the body is sent", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
request = request.WithContext(ctx)
|
|
||||||
request.Body = &mockBody{}
|
|
||||||
cancel()
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(dataStream.canceledRead).To(BeTrue())
|
|
||||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
|
||||||
Expect(client.headerErrored).ToNot(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the quic client when encountering an error on the header stream", func() {
|
|
||||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(client.headerErr))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InternalError))
|
|
||||||
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns subsequent request if there was an error on the header stream before", func() {
|
|
||||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
|
|
||||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
|
||||||
_, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
|
||||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InternalError))
|
|
||||||
// now that the first request failed due to an error on the header stream, try another request
|
|
||||||
_, nextErr := client.RoundTrip(request)
|
|
||||||
Expect(nextErr).To(MatchError(err))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("blocks if no stream is available", func() {
|
|
||||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
|
|
||||||
session.blockOpenStreamSync = true
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
|
||||||
// make the go routine return
|
|
||||||
client.Close()
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("validating the address", func() {
|
|
||||||
It("refuses to do requests for the wrong host", func() {
|
|
||||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = client.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("refuses to do plain HTTP requests", func() {
|
|
||||||
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = client.RoundTrip(req)
|
|
||||||
Expect(err).To(MatchError("quic http2: unsupported scheme"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds the port for request URLs without one", func() {
|
|
||||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
|
||||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
// the client.RoundTrip will block, because the encryption level is still set to Unencrypted
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(req)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the EndStream header for requests without a body", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
client.RoundTrip(request)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
|
|
||||||
mhf := getRequest(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the EndStream header to false for requests with a body", func() {
|
|
||||||
request.Body = &mockBody{}
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
client.RoundTrip(request)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
|
|
||||||
mhf := getRequest(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("requests containing a Body", func() {
|
|
||||||
var requestBody []byte
|
|
||||||
var response *http.Response
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
requestBody = []byte("request body")
|
|
||||||
body := &mockBody{}
|
|
||||||
body.SetData(requestBody)
|
|
||||||
request.Body = body
|
|
||||||
response = &http.Response{
|
|
||||||
StatusCode: 200,
|
|
||||||
Header: http.Header{"Content-Length": []string{"1000"}},
|
|
||||||
}
|
|
||||||
// fake a handshake
|
|
||||||
client.dialOnce.Do(func() {})
|
|
||||||
session.streamsToOpen = []quic.Stream{dataStream}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends a request", func() {
|
|
||||||
rspChan := make(chan *http.Response)
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rspChan <- rsp
|
|
||||||
}()
|
|
||||||
injectResponse(5, response)
|
|
||||||
Eventually(rspChan).Should(Receive(Equal(response)))
|
|
||||||
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
|
|
||||||
Expect(dataStream.closed).To(BeTrue())
|
|
||||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the error that occurred when reading the body", func() {
|
|
||||||
testErr := errors.New("testErr")
|
|
||||||
request.Body.(*mockBody).readErr = testErr
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the error that occurred when closing the body", func() {
|
|
||||||
testErr := errors.New("testErr")
|
|
||||||
request.Body.(*mockBody).closeErr = testErr
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
Expect(rsp).To(BeNil())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("gzip compression", func() {
|
|
||||||
var gzippedData []byte // a gzipped foobar
|
|
||||||
var response *http.Response
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var b bytes.Buffer
|
|
||||||
w := gzip.NewWriter(&b)
|
|
||||||
w.Write([]byte("foobar"))
|
|
||||||
w.Close()
|
|
||||||
gzippedData = b.Bytes()
|
|
||||||
response = &http.Response{
|
|
||||||
StatusCode: 200,
|
|
||||||
Header: http.Header{"Content-Length": []string{"1000"}},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds the gzip header to requests", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp).ToNot(BeNil())
|
|
||||||
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
|
|
||||||
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
|
|
||||||
Expect(rsp.Header.Get("Content-Length")).To(BeEmpty())
|
|
||||||
data := make([]byte, 6)
|
|
||||||
_, err = io.ReadFull(rsp.Body, data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(data).To(Equal([]byte("foobar")))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
dataStream.dataToRead.Write(gzippedData)
|
|
||||||
response.Header.Add("Content-Encoding", "gzip")
|
|
||||||
injectResponse(5, response)
|
|
||||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
|
||||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
|
||||||
close(dataStream.unblockRead)
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't add gzip if the header disable it", func() {
|
|
||||||
client.opts.DisableCompression = true
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
_, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
|
||||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
|
||||||
Expect(headers).ToNot(HaveKey("accept-encoding"))
|
|
||||||
// make the go routine return
|
|
||||||
injectResponse(5, &http.Response{})
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("only decompresses the response if the response contains the right content-encoding header", func() {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp).ToNot(BeNil())
|
|
||||||
data := make([]byte, 11)
|
|
||||||
rsp.Body.Read(data)
|
|
||||||
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
|
|
||||||
Expect(data).To(Equal([]byte("not gzipped")))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
dataStream.dataToRead.Write([]byte("not gzipped"))
|
|
||||||
injectResponse(5, response)
|
|
||||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
|
||||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't add the gzip header for requests that have the accept-enconding set", func() {
|
|
||||||
request.Header.Add("accept-encoding", "gzip")
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
rsp, err := client.RoundTrip(request)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
data := make([]byte, 12)
|
|
||||||
_, err = rsp.Body.Read(data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
|
|
||||||
Expect(data).To(Equal([]byte("gzipped data")))
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
dataStream.dataToRead.Write([]byte("gzipped data"))
|
|
||||||
injectResponse(5, response)
|
|
||||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
|
||||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
|
||||||
Eventually(done).Should(BeClosed())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("handling the header stream", func() {
|
|
||||||
var h2framer *http2.Framer
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
|
|
||||||
client.responses[23] = make(chan *http.Response)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads header values from a response", func() {
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
|
|
||||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
|
|
||||||
headerStream.dataToRead.Write(data)
|
|
||||||
go client.handleHeaderStream()
|
|
||||||
var rsp *http.Response
|
|
||||||
Eventually(client.responses[23]).Should(Receive(&rsp))
|
|
||||||
Expect(rsp).ToNot(BeNil())
|
|
||||||
Expect(rsp.Proto).To(Equal("HTTP/2.0"))
|
|
||||||
Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
|
|
||||||
Expect(rsp.StatusCode).To(BeEquivalentTo(302))
|
|
||||||
Expect(rsp.Status).To(Equal("302 Found"))
|
|
||||||
Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
|
|
||||||
Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the H2 frame is not a HeadersFrame", func() {
|
|
||||||
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
|
|
||||||
client.handleHeaderStream()
|
|
||||||
Eventually(client.headerErrored).Should(BeClosed())
|
|
||||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InternalError, "not a headers frame")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if it can't read the HPACK encoded header fields", func() {
|
|
||||||
h2framer.WriteHeaders(http2.HeadersFrameParam{
|
|
||||||
StreamID: 23,
|
|
||||||
EndHeaders: true,
|
|
||||||
BlockFragment: []byte("invalid HPACK data"),
|
|
||||||
})
|
|
||||||
client.handleHeaderStream()
|
|
||||||
Eventually(client.headerErrored).Should(BeClosed())
|
|
||||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InternalError))
|
|
||||||
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors if the stream cannot be found", func() {
|
|
||||||
var headers bytes.Buffer
|
|
||||||
enc := hpack.NewEncoder(&headers)
|
|
||||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
|
||||||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
|
|
||||||
StreamID: 1337,
|
|
||||||
EndHeaders: true,
|
|
||||||
BlockFragment: headers.Bytes(),
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
client.handleHeaderStream()
|
|
||||||
Eventually(client.headerErrored).Should(BeClosed())
|
|
||||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InternalError))
|
|
||||||
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -1,35 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
// copied from net/transport.go
|
|
||||||
|
|
||||||
// gzipReader wraps a response body so it can lazily
|
|
||||||
// call gzip.NewReader on the first call to Read
|
|
||||||
import (
|
|
||||||
"compress/gzip"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// call gzip.NewReader on the first call to Read
|
|
||||||
type gzipReader struct {
|
|
||||||
body io.ReadCloser // underlying Response.Body
|
|
||||||
zr *gzip.Reader // lazily-initialized gzip reader
|
|
||||||
zerr error // sticky error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (gz *gzipReader) Read(p []byte) (n int, err error) {
|
|
||||||
if gz.zerr != nil {
|
|
||||||
return 0, gz.zerr
|
|
||||||
}
|
|
||||||
if gz.zr == nil {
|
|
||||||
gz.zr, err = gzip.NewReader(gz.body)
|
|
||||||
if err != nil {
|
|
||||||
gz.zerr = err
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return gz.zr.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (gz *gzipReader) Close() error {
|
|
||||||
return gz.body.Close()
|
|
||||||
}
|
|
|
@ -1,13 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestH2quic(t *testing.T) {
|
|
||||||
RegisterFailHandler(Fail)
|
|
||||||
RunSpecs(t, "H2quic Suite")
|
|
||||||
}
|
|
|
@ -1,29 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
type requestBody struct {
|
|
||||||
requestRead bool
|
|
||||||
dataStream quic.Stream
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure the requestBody can be used as a http.Request.Body
|
|
||||||
var _ io.ReadCloser = &requestBody{}
|
|
||||||
|
|
||||||
func newRequestBody(stream quic.Stream) *requestBody {
|
|
||||||
return &requestBody{dataStream: stream}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *requestBody) Read(p []byte) (int, error) {
|
|
||||||
b.requestRead = true
|
|
||||||
return b.dataStream.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *requestBody) Close() error {
|
|
||||||
// stream's Close() closes the write side, not the read side
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,39 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Request body", func() {
|
|
||||||
var (
|
|
||||||
stream *mockStream
|
|
||||||
rb *requestBody
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
stream = &mockStream{}
|
|
||||||
stream.dataToRead.Write([]byte("foobar")) // provides data to be read
|
|
||||||
rb = newRequestBody(stream)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads from the stream", func() {
|
|
||||||
b := make([]byte, 10)
|
|
||||||
n, _ := stream.Read(b)
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(b[0:6]).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves if the stream was read from", func() {
|
|
||||||
Expect(rb.requestRead).To(BeFalse())
|
|
||||||
rb.Read(make([]byte, 1))
|
|
||||||
Expect(rb.requestRead).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't close the stream when closing the request body", func() {
|
|
||||||
Expect(stream.closed).To(BeFalse())
|
|
||||||
err := rb.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(stream.closed).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -1,121 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Request", func() {
|
|
||||||
It("populates request", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
{Name: "content-length", Value: "42"},
|
|
||||||
}
|
|
||||||
req, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(req.Method).To(Equal("GET"))
|
|
||||||
Expect(req.URL.Path).To(Equal("/foo"))
|
|
||||||
Expect(req.Proto).To(Equal("HTTP/2.0"))
|
|
||||||
Expect(req.ProtoMajor).To(Equal(2))
|
|
||||||
Expect(req.ProtoMinor).To(Equal(0))
|
|
||||||
Expect(req.ContentLength).To(Equal(int64(42)))
|
|
||||||
Expect(req.Header).To(BeEmpty())
|
|
||||||
Expect(req.Body).To(BeNil())
|
|
||||||
Expect(req.Host).To(Equal("quic.clemente.io"))
|
|
||||||
Expect(req.RequestURI).To(Equal("/foo"))
|
|
||||||
Expect(req.TLS).ToNot(BeNil())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("concatenates the cookie headers", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
{Name: "cookie", Value: "cookie1=foobar1"},
|
|
||||||
{Name: "cookie", Value: "cookie2=foobar2"},
|
|
||||||
}
|
|
||||||
req, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(req.Header).To(Equal(http.Header{
|
|
||||||
"Cookie": []string{"cookie1=foobar1; cookie2=foobar2"},
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles other headers", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
{Name: "cache-control", Value: "max-age=0"},
|
|
||||||
{Name: "duplicate-header", Value: "1"},
|
|
||||||
{Name: "duplicate-header", Value: "2"},
|
|
||||||
}
|
|
||||||
req, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(req.Header).To(Equal(http.Header{
|
|
||||||
"Cache-Control": []string{"max-age=0"},
|
|
||||||
"Duplicate-Header": []string{"1", "2"},
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with missing path", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
}
|
|
||||||
_, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with missing method", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":authority", Value: "quic.clemente.io"},
|
|
||||||
}
|
|
||||||
_, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors with missing authority", func() {
|
|
||||||
headers := []hpack.HeaderField{
|
|
||||||
{Name: ":path", Value: "/foo"},
|
|
||||||
{Name: ":method", Value: "GET"},
|
|
||||||
}
|
|
||||||
_, err := requestFromHeaders(headers)
|
|
||||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("extracting the hostname from a request", func() {
|
|
||||||
var url *url.URL
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
var err error
|
|
||||||
url, err = url.Parse("https://quic.clemente.io:1337")
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses req.URL.Host", func() {
|
|
||||||
req := &http.Request{URL: url}
|
|
||||||
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses req.URL.Host even if req.Host is available", func() {
|
|
||||||
req := &http.Request{
|
|
||||||
Host: "www.example.org",
|
|
||||||
URL: url,
|
|
||||||
}
|
|
||||||
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns an empty hostname if nothing is set", func() {
|
|
||||||
Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -1,203 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/net/http/httpguts"
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type requestWriter struct {
|
|
||||||
mutex sync.Mutex
|
|
||||||
headerStream quic.Stream
|
|
||||||
|
|
||||||
henc *hpack.Encoder
|
|
||||||
hbuf bytes.Buffer // HPACK encoder writes into this
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
const defaultUserAgent = "quic-go"
|
|
||||||
|
|
||||||
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
|
|
||||||
rw := &requestWriter{
|
|
||||||
headerStream: headerStream,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
rw.henc = hpack.NewEncoder(&rw.hbuf)
|
|
||||||
return rw
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
|
|
||||||
// TODO: add support for trailers
|
|
||||||
// TODO: add support for gzip compression
|
|
||||||
// TODO: write continuation frames, if the header frame is too long
|
|
||||||
|
|
||||||
w.mutex.Lock()
|
|
||||||
defer w.mutex.Unlock()
|
|
||||||
|
|
||||||
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
|
|
||||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
|
||||||
return h2framer.WriteHeaders(http2.HeadersFrameParam{
|
|
||||||
StreamID: uint32(dataStreamID),
|
|
||||||
EndHeaders: true,
|
|
||||||
EndStream: endStream,
|
|
||||||
BlockFragment: w.hbuf.Bytes(),
|
|
||||||
Priority: http2.PriorityParam{Weight: 0xff},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// the rest of this files is copied from http2.Transport
|
|
||||||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
|
|
||||||
w.hbuf.Reset()
|
|
||||||
|
|
||||||
host := req.Host
|
|
||||||
if host == "" {
|
|
||||||
host = req.URL.Host
|
|
||||||
}
|
|
||||||
host, err := httpguts.PunycodeHostPort(host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var path string
|
|
||||||
if req.Method != "CONNECT" {
|
|
||||||
path = req.URL.RequestURI()
|
|
||||||
if !validPseudoPath(path) {
|
|
||||||
orig := path
|
|
||||||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
|
||||||
if !validPseudoPath(path) {
|
|
||||||
if req.URL.Opaque != "" {
|
|
||||||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("invalid request :path %q", orig)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for any invalid headers and return an error before we
|
|
||||||
// potentially pollute our hpack state. (We want to be able to
|
|
||||||
// continue to reuse the hpack encoder for future requests)
|
|
||||||
for k, vv := range req.Header {
|
|
||||||
if !httpguts.ValidHeaderFieldName(k) {
|
|
||||||
return nil, fmt.Errorf("invalid HTTP header name %q", k)
|
|
||||||
}
|
|
||||||
for _, v := range vv {
|
|
||||||
if !httpguts.ValidHeaderFieldValue(v) {
|
|
||||||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 8.1.2.3 Request Pseudo-Header Fields
|
|
||||||
// The :path pseudo-header field includes the path and query parts of the
|
|
||||||
// target URI (the path-absolute production and optionally a '?' character
|
|
||||||
// followed by the query production (see Sections 3.3 and 3.4 of
|
|
||||||
// [RFC3986]).
|
|
||||||
w.writeHeader(":authority", host)
|
|
||||||
w.writeHeader(":method", req.Method)
|
|
||||||
if req.Method != "CONNECT" {
|
|
||||||
w.writeHeader(":path", path)
|
|
||||||
w.writeHeader(":scheme", req.URL.Scheme)
|
|
||||||
}
|
|
||||||
if trailers != "" {
|
|
||||||
w.writeHeader("trailer", trailers)
|
|
||||||
}
|
|
||||||
|
|
||||||
var didUA bool
|
|
||||||
for k, vv := range req.Header {
|
|
||||||
lowKey := strings.ToLower(k)
|
|
||||||
switch lowKey {
|
|
||||||
case "host", "content-length":
|
|
||||||
// Host is :authority, already sent.
|
|
||||||
// Content-Length is automatic, set below.
|
|
||||||
continue
|
|
||||||
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
|
|
||||||
// Per 8.1.2.2 Connection-Specific Header
|
|
||||||
// Fields, don't send connection-specific
|
|
||||||
// fields. We have already checked if any
|
|
||||||
// are error-worthy so just ignore the rest.
|
|
||||||
continue
|
|
||||||
case "user-agent":
|
|
||||||
// Match Go's http1 behavior: at most one
|
|
||||||
// User-Agent. If set to nil or empty string,
|
|
||||||
// then omit it. Otherwise if not mentioned,
|
|
||||||
// include the default (below).
|
|
||||||
didUA = true
|
|
||||||
if len(vv) < 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
vv = vv[:1]
|
|
||||||
if vv[0] == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, v := range vv {
|
|
||||||
w.writeHeader(lowKey, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if shouldSendReqContentLength(req.Method, contentLength) {
|
|
||||||
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
|
|
||||||
}
|
|
||||||
if addGzipHeader {
|
|
||||||
w.writeHeader("accept-encoding", "gzip")
|
|
||||||
}
|
|
||||||
if !didUA {
|
|
||||||
w.writeHeader("user-agent", defaultUserAgent)
|
|
||||||
}
|
|
||||||
return w.hbuf.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *requestWriter) writeHeader(name, value string) {
|
|
||||||
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
|
|
||||||
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
|
||||||
// a "content-length" request header. This logic is basically a copy of the net/http
|
|
||||||
// transferWriter.shouldSendContentLength.
|
|
||||||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
|
||||||
// -1 means unknown.
|
|
||||||
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
|
||||||
if contentLength > 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if contentLength < 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// For zero bodies, whether we send a content-length depends on the method.
|
|
||||||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
|
||||||
switch method {
|
|
||||||
case "POST", "PUT", "PATCH":
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validPseudoPath(v string) bool {
|
|
||||||
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
|
|
||||||
}
|
|
||||||
|
|
||||||
// actualContentLength returns a sanitized version of
|
|
||||||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
|
||||||
// means unknown.
|
|
||||||
func actualContentLength(req *http.Request) int64 {
|
|
||||||
if req.Body == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
if req.ContentLength != 0 {
|
|
||||||
return req.ContentLength
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
|
@ -1,114 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Request", func() {
|
|
||||||
var (
|
|
||||||
rw *requestWriter
|
|
||||||
headerStream *mockStream
|
|
||||||
decoder *hpack.Decoder
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
headerStream = &mockStream{}
|
|
||||||
rw = newRequestWriter(headerStream, utils.DefaultLogger)
|
|
||||||
decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
||||||
})
|
|
||||||
|
|
||||||
decode := func(p []byte) (*http2.HeadersFrame, map[string] /* HeaderField.Name */ string /* HeaderField.Value */) {
|
|
||||||
framer := http2.NewFramer(nil, bytes.NewReader(p))
|
|
||||||
frame, err := framer.ReadFrame()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
headerFrame := frame.(*http2.HeadersFrame)
|
|
||||||
fields, err := decoder.DecodeFull(headerFrame.HeaderBlockFragment())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
values := make(map[string]string)
|
|
||||||
for _, headerField := range fields {
|
|
||||||
values[headerField.Name] = headerField.Value
|
|
||||||
}
|
|
||||||
return headerFrame, values
|
|
||||||
}
|
|
||||||
|
|
||||||
It("writes a GET request", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 1337, true, false)
|
|
||||||
headerFrame, headerFields := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFrame.StreamID).To(Equal(uint32(1337)))
|
|
||||||
Expect(headerFrame.HasPriority()).To(BeTrue())
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
|
|
||||||
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the EndStream header", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 1337, true, false)
|
|
||||||
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFrame.StreamEnded()).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't set the EndStream header, if requested", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 1337, false, false)
|
|
||||||
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFrame.StreamEnded()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("requests gzip compression, if requested", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 1337, true, true)
|
|
||||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes a POST request", func() {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Add("foo", "bar")
|
|
||||||
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode()))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
rw.WriteRequest(req, 5, true, false)
|
|
||||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
|
||||||
Expect(headerFields).To(HaveKey("content-length"))
|
|
||||||
contentLength, err := strconv.Atoi(headerFields["content-length"])
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(contentLength).To(BeNumerically(">", 0))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends cookies", func() {
|
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cookie1 := &http.Cookie{
|
|
||||||
Name: "Cookie #1",
|
|
||||||
Value: "Value #1",
|
|
||||||
}
|
|
||||||
cookie2 := &http.Cookie{
|
|
||||||
Name: "Cookie #2",
|
|
||||||
Value: "Value #2",
|
|
||||||
}
|
|
||||||
req.AddCookie(cookie1)
|
|
||||||
req.AddCookie(cookie2)
|
|
||||||
rw.WriteRequest(req, 11, true, false)
|
|
||||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
|
||||||
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -1,95 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"net/textproto"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// copied from net/http2/transport.go
|
|
||||||
|
|
||||||
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
|
|
||||||
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
|
|
||||||
|
|
||||||
// from the handleResponse function
|
|
||||||
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
|
|
||||||
if f.Truncated {
|
|
||||||
return nil, errResponseHeaderListSize
|
|
||||||
}
|
|
||||||
|
|
||||||
status := f.PseudoValue("status")
|
|
||||||
if status == "" {
|
|
||||||
return nil, errors.New("missing status pseudo header")
|
|
||||||
}
|
|
||||||
statusCode, err := strconv.Atoi(status)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("malformed non-numeric status pseudo header")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: handle statusCode == 100
|
|
||||||
|
|
||||||
header := make(http.Header)
|
|
||||||
res := &http.Response{
|
|
||||||
Proto: "HTTP/2.0",
|
|
||||||
ProtoMajor: 2,
|
|
||||||
Header: header,
|
|
||||||
StatusCode: statusCode,
|
|
||||||
Status: status + " " + http.StatusText(statusCode),
|
|
||||||
}
|
|
||||||
for _, hf := range f.RegularFields() {
|
|
||||||
key := http.CanonicalHeaderKey(hf.Name)
|
|
||||||
if key == "Trailer" {
|
|
||||||
t := res.Trailer
|
|
||||||
if t == nil {
|
|
||||||
t = make(http.Header)
|
|
||||||
res.Trailer = t
|
|
||||||
}
|
|
||||||
foreachHeaderElement(hf.Value, func(v string) {
|
|
||||||
t[http.CanonicalHeaderKey(v)] = nil
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
header[key] = append(header[key], hf.Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// continuation of the handleResponse function
|
|
||||||
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
|
|
||||||
if !streamEnded || isHead {
|
|
||||||
res.ContentLength = -1
|
|
||||||
if clens := res.Header["Content-Length"]; len(clens) == 1 {
|
|
||||||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
|
|
||||||
res.ContentLength = clen64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// copied from net/http/server.go
|
|
||||||
|
|
||||||
// foreachHeaderElement splits v according to the "#rule" construction
|
|
||||||
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
|
|
||||||
func foreachHeaderElement(v string, fn func(string)) {
|
|
||||||
v = textproto.TrimString(v)
|
|
||||||
if v == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !strings.Contains(v, ",") {
|
|
||||||
fn(v)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, f := range strings.Split(v, ",") {
|
|
||||||
if f = textproto.TrimString(f); f != "" {
|
|
||||||
fn(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,163 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockStream struct {
|
|
||||||
id protocol.StreamID
|
|
||||||
dataToRead bytes.Buffer
|
|
||||||
dataWritten bytes.Buffer
|
|
||||||
canceledRead bool
|
|
||||||
canceledWrite bool
|
|
||||||
closed bool
|
|
||||||
remoteClosed bool
|
|
||||||
|
|
||||||
unblockRead chan struct{}
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ quic.Stream = &mockStream{}
|
|
||||||
|
|
||||||
func newMockStream(id protocol.StreamID) *mockStream {
|
|
||||||
s := &mockStream{
|
|
||||||
id: id,
|
|
||||||
unblockRead: make(chan struct{}),
|
|
||||||
}
|
|
||||||
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
|
|
||||||
func (s *mockStream) CancelRead(quic.ErrorCode) { s.canceledRead = true }
|
|
||||||
func (s *mockStream) CancelWrite(quic.ErrorCode) { s.canceledWrite = true }
|
|
||||||
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
|
|
||||||
func (s mockStream) StreamID() protocol.StreamID { return s.id }
|
|
||||||
func (s *mockStream) Context() context.Context { return s.ctx }
|
|
||||||
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
|
|
||||||
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
|
|
||||||
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }
|
|
||||||
|
|
||||||
func (s *mockStream) Read(p []byte) (int, error) {
|
|
||||||
n, _ := s.dataToRead.Read(p)
|
|
||||||
if n == 0 { // block if there's no data
|
|
||||||
<-s.unblockRead
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
return n, nil // never return an EOF
|
|
||||||
}
|
|
||||||
func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) }
|
|
||||||
|
|
||||||
var _ = Describe("Response Writer", func() {
|
|
||||||
var (
|
|
||||||
w *responseWriter
|
|
||||||
headerStream *mockStream
|
|
||||||
dataStream *mockStream
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
headerStream = &mockStream{}
|
|
||||||
dataStream = &mockStream{}
|
|
||||||
w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5, utils.DefaultLogger)
|
|
||||||
})
|
|
||||||
|
|
||||||
decodeHeaderFields := func() map[string][]string {
|
|
||||||
fields := make(map[string][]string)
|
|
||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
||||||
h2framer := http2.NewFramer(nil, bytes.NewReader(headerStream.dataWritten.Bytes()))
|
|
||||||
|
|
||||||
frame, err := h2framer.ReadFrame()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(frame).To(BeAssignableToTypeOf(&http2.HeadersFrame{}))
|
|
||||||
hframe := frame.(*http2.HeadersFrame)
|
|
||||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
|
||||||
Expect(mhframe.StreamID).To(BeEquivalentTo(5))
|
|
||||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
for _, p := range mhframe.Fields {
|
|
||||||
fields[p.Name] = append(fields[p.Name], p.Value)
|
|
||||||
}
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
|
|
||||||
It("writes status", func() {
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveLen(1))
|
|
||||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes headers", func() {
|
|
||||||
w.Header().Add("content-length", "42")
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes multiple headers with the same name", func() {
|
|
||||||
const cookie1 = "test1=1; Max-Age=7200; path=/"
|
|
||||||
const cookie2 = "test2=2; Max-Age=7200; path=/"
|
|
||||||
w.Header().Add("set-cookie", cookie1)
|
|
||||||
w.Header().Add("set-cookie", cookie2)
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveKey("set-cookie"))
|
|
||||||
cookies := fields["set-cookie"]
|
|
||||||
Expect(cookies).To(ContainElement(cookie1))
|
|
||||||
Expect(cookies).To(ContainElement(cookie2))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes data", func() {
|
|
||||||
n, err := w.Write([]byte("foobar"))
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// Should have written 200 on the header stream
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
|
|
||||||
// And foobar on the data stream
|
|
||||||
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("writes data after WriteHeader is called", func() {
|
|
||||||
w.WriteHeader(http.StatusTeapot)
|
|
||||||
n, err := w.Write([]byte("foobar"))
|
|
||||||
Expect(n).To(Equal(6))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
// Should have written 418 on the header stream
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
|
||||||
// And foobar on the data stream
|
|
||||||
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not WriteHeader() twice", func() {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
w.WriteHeader(500)
|
|
||||||
fields := decodeHeaderFields()
|
|
||||||
Expect(fields).To(HaveLen(1))
|
|
||||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't allow writes if the status code doesn't allow a body", func() {
|
|
||||||
w.WriteHeader(304)
|
|
||||||
n, err := w.Write([]byte("foobar"))
|
|
||||||
Expect(n).To(BeZero())
|
|
||||||
Expect(err).To(MatchError(http.ErrBodyNotAllowed))
|
|
||||||
Expect(dataStream.dataWritten.Bytes()).To(HaveLen(0))
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -1,536 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockSession struct {
|
|
||||||
closed bool
|
|
||||||
closedWithError error
|
|
||||||
dataStream quic.Stream
|
|
||||||
streamToAccept quic.Stream
|
|
||||||
streamsToOpen []quic.Stream
|
|
||||||
blockOpenStreamSync bool
|
|
||||||
blockOpenStreamChan chan struct{} // close this chan (or call Close) to make OpenStreamSync return
|
|
||||||
streamOpenErr error
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockSession() *mockSession {
|
|
||||||
return &mockSession{blockOpenStreamChan: make(chan struct{})}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
|
|
||||||
return s.dataStream, nil
|
|
||||||
}
|
|
||||||
func (s *mockSession) AcceptStream() (quic.Stream, error) { return s.streamToAccept, nil }
|
|
||||||
func (s *mockSession) OpenStream() (quic.Stream, error) {
|
|
||||||
if s.streamOpenErr != nil {
|
|
||||||
return nil, s.streamOpenErr
|
|
||||||
}
|
|
||||||
str := s.streamsToOpen[0]
|
|
||||||
s.streamsToOpen = s.streamsToOpen[1:]
|
|
||||||
return str, nil
|
|
||||||
}
|
|
||||||
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
|
|
||||||
if s.blockOpenStreamSync {
|
|
||||||
<-s.blockOpenStreamChan
|
|
||||||
}
|
|
||||||
return s.OpenStream()
|
|
||||||
}
|
|
||||||
func (s *mockSession) Close() error {
|
|
||||||
s.ctxCancel()
|
|
||||||
if !s.closed {
|
|
||||||
close(s.blockOpenStreamChan)
|
|
||||||
}
|
|
||||||
s.closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error {
|
|
||||||
s.closedWithError = e
|
|
||||||
return s.Close()
|
|
||||||
}
|
|
||||||
func (s *mockSession) LocalAddr() net.Addr {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
func (s *mockSession) RemoteAddr() net.Addr {
|
|
||||||
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
|
|
||||||
}
|
|
||||||
func (s *mockSession) Context() context.Context {
|
|
||||||
return s.ctx
|
|
||||||
}
|
|
||||||
func (s *mockSession) ConnectionState() tls.ConnectionState { panic("not implemented") }
|
|
||||||
func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") }
|
|
||||||
func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") }
|
|
||||||
func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") }
|
|
||||||
|
|
||||||
var _ = Describe("H2 server", func() {
|
|
||||||
var (
|
|
||||||
s *Server
|
|
||||||
session *mockSession
|
|
||||||
dataStream *mockStream
|
|
||||||
origQuicListenAddr = quicListenAddr
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
s = &Server{
|
|
||||||
Server: &http.Server{
|
|
||||||
TLSConfig: testdata.GetTLSConfig(),
|
|
||||||
},
|
|
||||||
logger: utils.DefaultLogger,
|
|
||||||
}
|
|
||||||
dataStream = newMockStream(0)
|
|
||||||
close(dataStream.unblockRead)
|
|
||||||
session = newMockSession()
|
|
||||||
session.dataStream = dataStream
|
|
||||||
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
|
||||||
origQuicListenAddr = quicListenAddr
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
quicListenAddr = origQuicListenAddr
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("handling requests", func() {
|
|
||||||
var (
|
|
||||||
h2framer *http2.Framer
|
|
||||||
hpackDecoder *hpack.Decoder
|
|
||||||
headerStream *mockStream
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
headerStream = &mockStream{}
|
|
||||||
hpackDecoder = hpack.NewDecoder(4096, nil)
|
|
||||||
h2framer = http2.NewFramer(nil, headerStream)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles a sample GET request", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
Expect(r.RemoteAddr).To(Equal("127.0.0.1:42"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
Expect(dataStream.remoteClosed).To(BeTrue())
|
|
||||||
Expect(dataStream.canceledRead).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns 200 with an empty handler", func() {
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() []byte {
|
|
||||||
return headerStream.dataWritten.Bytes()
|
|
||||||
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200
|
|
||||||
})
|
|
||||||
|
|
||||||
It("correctly handles a panicking handler", func() {
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
panic("foobar")
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() []byte {
|
|
||||||
return headerStream.dataWritten.Bytes()
|
|
||||||
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500
|
|
||||||
})
|
|
||||||
|
|
||||||
It("resets the dataStream when client sends a body in GET request", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue())
|
|
||||||
Expect(dataStream.remoteClosed).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("resets the dataStream when the body of POST request is not read", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
Expect(r.Method).To(Equal("POST"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue())
|
|
||||||
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
|
|
||||||
Expect(handlerCalled).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles a request for which the client immediately resets the data stream", func() {
|
|
||||||
session.dataStream = nil
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("resets the dataStream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
r.Body = struct {
|
|
||||||
io.Reader
|
|
||||||
io.Closer
|
|
||||||
}{}
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return dataStream.canceledRead }).Should(BeTrue())
|
|
||||||
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
|
|
||||||
Expect(handlerCalled).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the dataStream if the body of POST request was read", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
Expect(r.Method).To(Equal("POST"))
|
|
||||||
handlerCalled = true
|
|
||||||
// read the request body
|
|
||||||
b := make([]byte, 1000)
|
|
||||||
n, _ := r.Body.Read(b)
|
|
||||||
Expect(n).ToNot(BeZero())
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
|
||||||
dataStream.dataToRead.Write([]byte("foo=bar"))
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
Expect(dataStream.canceledRead).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores PRIORITY frames", func() {
|
|
||||||
handlerCalled := make(chan struct{})
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
close(handlerCalled)
|
|
||||||
})
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
framer := http2.NewFramer(buf, nil)
|
|
||||||
err := framer.WritePriority(10, http2.PriorityParam{Weight: 42})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(buf.Bytes()).ToNot(BeEmpty())
|
|
||||||
headerStream.dataToRead.Write(buf.Bytes())
|
|
||||||
err = s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Consistently(handlerCalled).ShouldNot(BeClosed())
|
|
||||||
Expect(dataStream.canceledRead).To(BeFalse())
|
|
||||||
Expect(dataStream.closed).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when non-header frames are received", func() {
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x06, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
'f', 'o', 'o', 'b', 'a', 'r',
|
|
||||||
})
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).To(MatchError("PROTOCOL_VIOLATION: expected a header frame"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("Cancels the request context when the datstream is closed", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
err := r.Context().Err()
|
|
||||||
Expect(err).To(HaveOccurred())
|
|
||||||
Expect(err.Error()).To(Equal("context canceled"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
dataStream.Close()
|
|
||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
Expect(dataStream.remoteClosed).To(BeTrue())
|
|
||||||
Expect(dataStream.canceledRead).To(BeFalse())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("handles the header stream", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream := &mockStream{id: 3}
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
session.streamToAccept = headerStream
|
|
||||||
go s.handleHeaderStream(session)
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the connection if it encounters an error on the header stream", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
handlerCalled = true
|
|
||||||
})
|
|
||||||
headerStream := &mockStream{id: 3}
|
|
||||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
|
||||||
session.streamToAccept = headerStream
|
|
||||||
go s.handleHeaderStream(session)
|
|
||||||
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
|
|
||||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
|
||||||
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.InternalError, "cannot read frame")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("supports closing after first request", func() {
|
|
||||||
s.CloseAfterFirstRequest = true
|
|
||||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
headerStream := &mockStream{id: 3}
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
session.streamToAccept = headerStream
|
|
||||||
Expect(session.closed).To(BeFalse())
|
|
||||||
go s.handleHeaderStream(session)
|
|
||||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the default handler as fallback", func() {
|
|
||||||
var handlerCalled bool
|
|
||||||
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Expect(r.Host).To(Equal("www.example.com"))
|
|
||||||
handlerCalled = true
|
|
||||||
}))
|
|
||||||
headerStream := &mockStream{id: 3}
|
|
||||||
headerStream.dataToRead.Write([]byte{
|
|
||||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
|
||||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
|
||||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
|
||||||
})
|
|
||||||
session.streamToAccept = headerStream
|
|
||||||
go s.handleHeaderStream(session)
|
|
||||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("setting http headers", func() {
|
|
||||||
var expected http.Header
|
|
||||||
|
|
||||||
getExpectedHeader := func(versions []protocol.VersionNumber) http.Header {
|
|
||||||
var versionsAsString []string
|
|
||||||
for _, v := range versions {
|
|
||||||
versionsAsString = append(versionsAsString, v.ToAltSvc())
|
|
||||||
}
|
|
||||||
return http.Header{
|
|
||||||
"Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
Expect(getExpectedHeader([]protocol.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}}))
|
|
||||||
expected = getExpectedHeader(protocol.SupportedVersions)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets proper headers with numeric port", func() {
|
|
||||||
s.Server.Addr = ":443"
|
|
||||||
hdr := http.Header{}
|
|
||||||
err := s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets proper headers with full addr", func() {
|
|
||||||
s.Server.Addr = "127.0.0.1:443"
|
|
||||||
hdr := http.Header{}
|
|
||||||
err := s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets proper headers with string port", func() {
|
|
||||||
s.Server.Addr = ":https"
|
|
||||||
hdr := http.Header{}
|
|
||||||
err := s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("works multiple times", func() {
|
|
||||||
s.Server.Addr = ":https"
|
|
||||||
hdr := http.Header{}
|
|
||||||
err := s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
hdr = http.Header{}
|
|
||||||
err = s.SetQuicHeaders(hdr)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
Expect(hdr).To(Equal(expected))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("should error when ListenAndServe is called with s.Server nil", func() {
|
|
||||||
err := (&Server{}).ListenAndServe()
|
|
||||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
|
|
||||||
err := (&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())
|
|
||||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("should nop-Close() when s.server is nil", func() {
|
|
||||||
err := (&Server{}).Close()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when ListenAndServer is called after Close", func() {
|
|
||||||
serv := &Server{Server: &http.Server{}}
|
|
||||||
Expect(serv.Close()).To(Succeed())
|
|
||||||
err := serv.ListenAndServe()
|
|
||||||
Expect(err).To(MatchError("Server is already closed"))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ListenAndServe", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
s.Server.Addr = "localhost:0"
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
Expect(s.Close()).To(Succeed())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("may only be called once", func() {
|
|
||||||
cErr := make(chan error)
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
err := s.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
cErr <- err
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
err := <-cErr
|
|
||||||
Expect(err).To(MatchError("ListenAndServe may only be called once"))
|
|
||||||
err = s.Close()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}, 0.5)
|
|
||||||
|
|
||||||
It("uses the quic.Config to start the quic server", func() {
|
|
||||||
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
|
||||||
var receivedConf *quic.Config
|
|
||||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
|
||||||
receivedConf = config
|
|
||||||
return nil, errors.New("listen err")
|
|
||||||
}
|
|
||||||
s.QuicConfig = conf
|
|
||||||
go s.ListenAndServe()
|
|
||||||
Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("ListenAndServeTLS", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
s.Server.Addr = "localhost:0"
|
|
||||||
})
|
|
||||||
|
|
||||||
AfterEach(func() {
|
|
||||||
err := s.Close()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("may only be called once", func() {
|
|
||||||
cErr := make(chan error)
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
err := s.ListenAndServeTLS(testdata.GetCertificatePaths())
|
|
||||||
if err != nil {
|
|
||||||
cErr <- err
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
err := <-cErr
|
|
||||||
Expect(err).To(MatchError("ListenAndServe may only be called once"))
|
|
||||||
err = s.Close()
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
}, 0.5)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes gracefully", func() {
|
|
||||||
err := s.CloseGracefully(0)
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when listening fails", func() {
|
|
||||||
testErr := errors.New("listen error")
|
|
||||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
|
||||||
return nil, testErr
|
|
||||||
}
|
|
||||||
fullpem, privkey := testdata.GetCertificatePaths()
|
|
||||||
err := ListenAndServeQUIC("", fullpem, privkey, nil)
|
|
||||||
Expect(err).To(MatchError(testErr))
|
|
||||||
})
|
|
||||||
})
|
|
68
http3/body.go
Normal file
68
http3/body.go
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The body of a http.Request or http.Response.
|
||||||
|
type body struct {
|
||||||
|
str io.ReadCloser
|
||||||
|
|
||||||
|
isRequest bool
|
||||||
|
|
||||||
|
bytesRemainingInFrame uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ io.ReadCloser = &body{}
|
||||||
|
|
||||||
|
func newRequestBody(str io.ReadCloser) *body {
|
||||||
|
return &body{
|
||||||
|
str: str,
|
||||||
|
isRequest: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResponseBody(str io.ReadCloser) *body {
|
||||||
|
return &body{str: str}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *body) Read(b []byte) (int, error) {
|
||||||
|
if r.bytesRemainingInFrame == 0 {
|
||||||
|
parseLoop:
|
||||||
|
for {
|
||||||
|
frame, err := parseNextFrame(r.str)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
switch f := frame.(type) {
|
||||||
|
case *headersFrame:
|
||||||
|
// skip HEADERS frames
|
||||||
|
continue
|
||||||
|
case *dataFrame:
|
||||||
|
r.bytesRemainingInFrame = f.Length
|
||||||
|
break parseLoop
|
||||||
|
default:
|
||||||
|
return 0, errors.New("unexpected frame")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int
|
||||||
|
var err error
|
||||||
|
if r.bytesRemainingInFrame < uint64(len(b)) {
|
||||||
|
n, err = r.str.Read(b[:r.bytesRemainingInFrame])
|
||||||
|
} else {
|
||||||
|
n, err = r.str.Read(b)
|
||||||
|
}
|
||||||
|
r.bytesRemainingInFrame -= uint64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *body) Close() error {
|
||||||
|
// quic.Stream.Close() closes the write side, not the read side
|
||||||
|
if r.isRequest {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.str.Close()
|
||||||
|
}
|
150
http3/body_test.go
Normal file
150
http3/body_test.go
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
type closingBuffer struct {
|
||||||
|
*bytes.Buffer
|
||||||
|
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *closingBuffer) Close() error { b.closed = true; return nil }
|
||||||
|
|
||||||
|
type bodyType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
bodyTypeRequest bodyType = iota
|
||||||
|
bodyTypeResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Body", func() {
|
||||||
|
var rb *body
|
||||||
|
var buf *bytes.Buffer
|
||||||
|
|
||||||
|
getDataFrame := func(data []byte) []byte {
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
(&dataFrame{Length: uint64(len(data))}).Write(b)
|
||||||
|
b.Write(data)
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
buf = &bytes.Buffer{}
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
|
||||||
|
bodyType := bt
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
cb := &closingBuffer{Buffer: buf}
|
||||||
|
switch bodyType {
|
||||||
|
case bodyTypeRequest:
|
||||||
|
rb = newRequestBody(cb)
|
||||||
|
case bodyTypeResponse:
|
||||||
|
rb = newResponseBody(cb)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads DATA frames in a single run", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foobar")))
|
||||||
|
b := make([]byte, 6)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(b).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads DATA frames in multiple runs", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foobar")))
|
||||||
|
b := make([]byte, 3)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(b).To(Equal([]byte("foo")))
|
||||||
|
n, err = rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(b).To(Equal([]byte("bar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads DATA frames into too large buffers", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foobar")))
|
||||||
|
b := make([]byte, 10)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(b[:n]).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads DATA frames into too large buffers, in multiple runs", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foobar")))
|
||||||
|
b := make([]byte, 4)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(4))
|
||||||
|
Expect(b).To(Equal([]byte("foob")))
|
||||||
|
n, err = rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(2))
|
||||||
|
Expect(b[:n]).To(Equal([]byte("ar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads multiple DATA frames", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foo")))
|
||||||
|
buf.Write(getDataFrame([]byte("bar")))
|
||||||
|
b := make([]byte, 6)
|
||||||
|
n, err := rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(b[:n]).To(Equal([]byte("foo")))
|
||||||
|
n, err = rb.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(b[:n]).To(Equal([]byte("bar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("skips HEADERS frames", func() {
|
||||||
|
buf.Write(getDataFrame([]byte("foo")))
|
||||||
|
(&headersFrame{Length: 10}).Write(buf)
|
||||||
|
buf.Write(make([]byte, 10))
|
||||||
|
buf.Write(getDataFrame([]byte("bar")))
|
||||||
|
b := make([]byte, 6)
|
||||||
|
n, err := io.ReadFull(rb, b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(b).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when it can't parse the frame", func() {
|
||||||
|
buf.Write([]byte("invalid"))
|
||||||
|
_, err := rb.Read([]byte{0})
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors on unexpected frames", func() {
|
||||||
|
(&settingsFrame{}).Write(buf)
|
||||||
|
_, err := rb.Read([]byte{0})
|
||||||
|
Expect(err).To(MatchError("unexpected frame"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
It("closes requests", func() {
|
||||||
|
cb := &closingBuffer{Buffer: buf}
|
||||||
|
rb := newRequestBody(cb)
|
||||||
|
Expect(rb.Close()).To(Succeed())
|
||||||
|
Expect(cb.closed).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes responses", func() {
|
||||||
|
cb := &closingBuffer{Buffer: buf}
|
||||||
|
rb := newResponseBody(cb)
|
||||||
|
Expect(rb.Close()).To(Succeed())
|
||||||
|
Expect(cb.closed).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
182
http3/client.go
Normal file
182
http3/client.go
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/marten-seemann/qpack"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultUserAgent = "quic-go HTTP/3"
|
||||||
|
|
||||||
|
var defaultQuicConfig = &quic.Config{KeepAlive: true}
|
||||||
|
|
||||||
|
var dialAddr = quic.DialAddr
|
||||||
|
|
||||||
|
type roundTripperOpts struct {
|
||||||
|
DisableCompression bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// client is a HTTP3 client doing requests
|
||||||
|
type client struct {
|
||||||
|
tlsConf *tls.Config
|
||||||
|
config *quic.Config
|
||||||
|
|
||||||
|
dialOnce sync.Once
|
||||||
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||||
|
handshakeErr error
|
||||||
|
|
||||||
|
requestWriter *requestWriter
|
||||||
|
|
||||||
|
decoder *qpack.Decoder
|
||||||
|
|
||||||
|
hostname string
|
||||||
|
session quic.Session
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClient(
|
||||||
|
hostname string,
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
_ *roundTripperOpts, // TODO: implement gzip compression
|
||||||
|
quicConfig *quic.Config,
|
||||||
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||||
|
) *client {
|
||||||
|
if tlsConf == nil {
|
||||||
|
tlsConf = &tls.Config{}
|
||||||
|
}
|
||||||
|
tlsConf.NextProtos = []string{"h3-19"}
|
||||||
|
if quicConfig == nil {
|
||||||
|
quicConfig = defaultQuicConfig
|
||||||
|
}
|
||||||
|
quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
||||||
|
logger := utils.DefaultLogger.WithPrefix("h3 client")
|
||||||
|
|
||||||
|
return &client{
|
||||||
|
hostname: authorityAddr("https", hostname),
|
||||||
|
tlsConf: tlsConf,
|
||||||
|
requestWriter: newRequestWriter(logger),
|
||||||
|
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
|
||||||
|
config: quicConfig,
|
||||||
|
dialer: dialer,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dial() error {
|
||||||
|
var err error
|
||||||
|
if c.dialer != nil {
|
||||||
|
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||||
|
} else {
|
||||||
|
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := c.setupSession(); err != nil {
|
||||||
|
c.session.CloseWithError(quic.ErrorCode(errorInternalError), err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// TODO: send a SETTINGS frame
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) setupSession() error {
|
||||||
|
// open the control stream
|
||||||
|
str, err := c.session.OpenUniStreamSync()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
// write the type byte
|
||||||
|
buf.Write([]byte{0x0})
|
||||||
|
// send the SETTINGS frame
|
||||||
|
(&settingsFrame{}).Write(buf)
|
||||||
|
if _, err := str.Write(buf.Bytes()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) Close() error {
|
||||||
|
return c.session.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Roundtrip executes a request and returns a response
|
||||||
|
// TODO: handle request cancelations
|
||||||
|
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if req.URL.Scheme != "https" {
|
||||||
|
return nil, errors.New("http3: unsupported scheme")
|
||||||
|
}
|
||||||
|
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
||||||
|
return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.dialOnce.Do(func() {
|
||||||
|
c.handshakeErr = c.dial()
|
||||||
|
})
|
||||||
|
|
||||||
|
if c.handshakeErr != nil {
|
||||||
|
return nil, c.handshakeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
str, err := c.session.OpenStreamSync()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.requestWriter.WriteRequest(str, req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
frame, err := parseNextFrame(str)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hf, ok := frame.(*headersFrame)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("not a HEADERS frame")
|
||||||
|
}
|
||||||
|
// TODO: check size
|
||||||
|
headerBlock := make([]byte, hf.Length)
|
||||||
|
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hfs, err := c.decoder.DecodeFull(headerBlock)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res := &http.Response{
|
||||||
|
Proto: "HTTP/3",
|
||||||
|
ProtoMajor: 3,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: newResponseBody(&responseBody{str}),
|
||||||
|
}
|
||||||
|
for _, hf := range hfs {
|
||||||
|
switch hf.Name {
|
||||||
|
case ":status":
|
||||||
|
status, err := strconv.Atoi(hf.Value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("malformed non-numeric status pseudo header")
|
||||||
|
}
|
||||||
|
res.StatusCode = status
|
||||||
|
res.Status = hf.Value + " " + http.StatusText(status)
|
||||||
|
default:
|
||||||
|
res.Header.Add(hf.Name, hf.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
272
http3/client_test.go
Normal file
272
http3/client_test.go
Normal file
|
@ -0,0 +1,272 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/marten-seemann/qpack"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Client", func() {
|
||||||
|
var (
|
||||||
|
client *client
|
||||||
|
req *http.Request
|
||||||
|
origDialAddr = dialAddr
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
origDialAddr = dialAddr
|
||||||
|
hostname := "quic.clemente.io:1337"
|
||||||
|
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
Expect(client.hostname).To(Equal(hostname))
|
||||||
|
|
||||||
|
var err error
|
||||||
|
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
dialAddr = origDialAddr
|
||||||
|
})
|
||||||
|
|
||||||
|
It("uses the default QUIC config if none is give", func() {
|
||||||
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
var dialAddrCalled bool
|
||||||
|
dialAddr = func(_ string, _ *tls.Config, quicConf *quic.Config) (quic.Session, error) {
|
||||||
|
Expect(quicConf).To(Equal(defaultQuicConfig))
|
||||||
|
dialAddrCalled = true
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
}
|
||||||
|
client.RoundTrip(req)
|
||||||
|
Expect(dialAddrCalled).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("adds the port to the hostname, if none is given", func() {
|
||||||
|
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
var dialAddrCalled bool
|
||||||
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
|
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
||||||
|
dialAddrCalled = true
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
client.RoundTrip(req)
|
||||||
|
Expect(dialAddrCalled).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("uses the TLS config and QUIC config", func() {
|
||||||
|
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
||||||
|
quicConf := &quic.Config{IdleTimeout: time.Nanosecond}
|
||||||
|
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
|
||||||
|
var dialAddrCalled bool
|
||||||
|
dialAddr = func(
|
||||||
|
hostname string,
|
||||||
|
tlsConfP *tls.Config,
|
||||||
|
quicConfP *quic.Config,
|
||||||
|
) (quic.Session, error) {
|
||||||
|
Expect(hostname).To(Equal("localhost:1337"))
|
||||||
|
Expect(tlsConfP).To(Equal(tlsConf))
|
||||||
|
Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout))
|
||||||
|
dialAddrCalled = true
|
||||||
|
return nil, errors.New("test done")
|
||||||
|
}
|
||||||
|
client.RoundTrip(req)
|
||||||
|
Expect(dialAddrCalled).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("uses the custom dialer, if provided", func() {
|
||||||
|
testErr := errors.New("test done")
|
||||||
|
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
||||||
|
quicConf := &quic.Config{IdleTimeout: 1337 * time.Second}
|
||||||
|
var dialerCalled bool
|
||||||
|
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.Session, error) {
|
||||||
|
Expect(network).To(Equal("udp"))
|
||||||
|
Expect(address).To(Equal("localhost:1337"))
|
||||||
|
Expect(tlsConfP).To(Equal(tlsConf))
|
||||||
|
Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout))
|
||||||
|
dialerCalled = true
|
||||||
|
return nil, testErr
|
||||||
|
}
|
||||||
|
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
||||||
|
_, err := client.RoundTrip(req)
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
Expect(dialerCalled).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when dialing fails", func() {
|
||||||
|
testErr := errors.New("handshake error")
|
||||||
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
|
return nil, testErr
|
||||||
|
}
|
||||||
|
_, err := client.RoundTrip(req)
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors if it can't open a stream", func() {
|
||||||
|
testErr := errors.New("stream open error")
|
||||||
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
session := mockquic.NewMockSession(mockCtrl)
|
||||||
|
session.EXPECT().OpenUniStreamSync().Return(nil, testErr).MaxTimes(1)
|
||||||
|
session.EXPECT().OpenStreamSync().Return(nil, testErr).MaxTimes(1)
|
||||||
|
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||||
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := client.RoundTrip(req)
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("Doing requests", func() {
|
||||||
|
var (
|
||||||
|
request *http.Request
|
||||||
|
str *mockquic.MockStream
|
||||||
|
sess *mockquic.MockSession
|
||||||
|
)
|
||||||
|
|
||||||
|
decodeHeader := func(str io.Reader) map[string]string {
|
||||||
|
fields := make(map[string]string)
|
||||||
|
decoder := qpack.NewDecoder(nil)
|
||||||
|
|
||||||
|
frame, err := parseNextFrame(str)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||||
|
headersFrame := frame.(*headersFrame)
|
||||||
|
data := make([]byte, headersFrame.Length)
|
||||||
|
_, err = io.ReadFull(str, data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
hfs, err := decoder.DecodeFull(data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
for _, p := range hfs {
|
||||||
|
fields[p.Name] = p.Value
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||||
|
controlStr.EXPECT().Write([]byte{0x0}).Return(1, nil).MaxTimes(1)
|
||||||
|
controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame
|
||||||
|
str = mockquic.NewMockStream(mockCtrl)
|
||||||
|
sess = mockquic.NewMockSession(mockCtrl)
|
||||||
|
sess.EXPECT().OpenUniStreamSync().Return(controlStr, nil).MaxTimes(1)
|
||||||
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sends a request", func() {
|
||||||
|
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return buf.Write(p)
|
||||||
|
})
|
||||||
|
str.EXPECT().Close()
|
||||||
|
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
||||||
|
_, err := client.RoundTrip(request)
|
||||||
|
Expect(err).To(MatchError("test done"))
|
||||||
|
hfs := decodeHeader(buf)
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":scheme", "https"))
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":method", "GET"))
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":authority", "quic.clemente.io:1337"))
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":path", "/file1.dat"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns a response", func() {
|
||||||
|
rspBuf := &bytes.Buffer{}
|
||||||
|
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
|
||||||
|
rw.WriteHeader(418)
|
||||||
|
|
||||||
|
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||||
|
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
||||||
|
str.EXPECT().Close()
|
||||||
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return rspBuf.Read(p)
|
||||||
|
}).AnyTimes()
|
||||||
|
rsp, err := client.RoundTrip(request)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(rsp.Proto).To(Equal("HTTP/3"))
|
||||||
|
Expect(rsp.ProtoMajor).To(Equal(3))
|
||||||
|
Expect(rsp.StatusCode).To(Equal(418))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("validating the address", func() {
|
||||||
|
It("refuses to do requests for the wrong host", func() {
|
||||||
|
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = client.RoundTrip(req)
|
||||||
|
Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("refuses to do plain HTTP requests", func() {
|
||||||
|
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = client.RoundTrip(req)
|
||||||
|
Expect(err).To(MatchError("http3: unsupported scheme"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("requests containing a Body", func() {
|
||||||
|
var strBuf *bytes.Buffer
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
strBuf = &bytes.Buffer{}
|
||||||
|
sess.EXPECT().OpenStreamSync().Return(str, nil)
|
||||||
|
body := &mockBody{}
|
||||||
|
body.SetData([]byte("request body"))
|
||||||
|
var err error
|
||||||
|
request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return strBuf.Write(p)
|
||||||
|
}).AnyTimes()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sends a request", func() {
|
||||||
|
done := make(chan struct{})
|
||||||
|
str.EXPECT().Close().Do(func() { close(done) })
|
||||||
|
// the response body is sent asynchronously, while already reading the response
|
||||||
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
|
||||||
|
<-done
|
||||||
|
return 0, errors.New("test done")
|
||||||
|
})
|
||||||
|
_, err := client.RoundTrip(request)
|
||||||
|
Expect(err).To(MatchError("test done"))
|
||||||
|
hfs := decodeHeader(strBuf)
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the error that occurred when reading the body", func() {
|
||||||
|
request.Body.(*mockBody).readErr = errors.New("testErr")
|
||||||
|
done := make(chan struct{})
|
||||||
|
str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
// the response body is sent asynchronously, while already reading the response
|
||||||
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
|
||||||
|
<-done
|
||||||
|
return 0, errors.New("test done")
|
||||||
|
})
|
||||||
|
_, err := client.RoundTrip(request)
|
||||||
|
Expect(err).To(MatchError("test done"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
135
http3/frames.go
Normal file
135
http3/frames.go
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type byteReader interface {
|
||||||
|
io.ByteReader
|
||||||
|
io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
type byteReaderImpl struct{ io.Reader }
|
||||||
|
|
||||||
|
func (br *byteReaderImpl) ReadByte() (byte, error) {
|
||||||
|
b := make([]byte, 1)
|
||||||
|
if _, err := br.Reader.Read(b); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return b[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type frame interface{}
|
||||||
|
|
||||||
|
func parseNextFrame(b io.Reader) (frame, error) {
|
||||||
|
br, ok := b.(byteReader)
|
||||||
|
if !ok {
|
||||||
|
br = &byteReaderImpl{b}
|
||||||
|
}
|
||||||
|
t, err := utils.ReadVarInt(br)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l, err := utils.ReadVarInt(br)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case 0x0:
|
||||||
|
return &dataFrame{Length: l}, nil
|
||||||
|
case 0x1:
|
||||||
|
return &headersFrame{Length: l}, nil
|
||||||
|
case 0x4:
|
||||||
|
return parseSettingsFrame(br, l)
|
||||||
|
case 0x2: // PRIORITY
|
||||||
|
fallthrough
|
||||||
|
case 0x3: // CANCEL_PUSH
|
||||||
|
fallthrough
|
||||||
|
case 0x5: // PUSH_PROMISE
|
||||||
|
fallthrough
|
||||||
|
case 0x7: // GOAWAY
|
||||||
|
fallthrough
|
||||||
|
case 0xd: // MAX_PUSH_ID
|
||||||
|
fallthrough
|
||||||
|
case 0xe: // DUPLICATE_PUSH
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
// skip over unknown frames
|
||||||
|
if _, err := io.CopyN(ioutil.Discard, br, int64(l)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return parseNextFrame(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type dataFrame struct {
|
||||||
|
Length uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *dataFrame) Write(b *bytes.Buffer) {
|
||||||
|
utils.WriteVarInt(b, 0x0)
|
||||||
|
utils.WriteVarInt(b, f.Length)
|
||||||
|
}
|
||||||
|
|
||||||
|
type headersFrame struct {
|
||||||
|
Length uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *headersFrame) Write(b *bytes.Buffer) {
|
||||||
|
utils.WriteVarInt(b, 0x1)
|
||||||
|
utils.WriteVarInt(b, f.Length)
|
||||||
|
}
|
||||||
|
|
||||||
|
type settingsFrame struct {
|
||||||
|
settings map[uint64]uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
|
||||||
|
if l > 8*(1<<10) {
|
||||||
|
return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l)
|
||||||
|
}
|
||||||
|
buf := make([]byte, l)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
if err == io.ErrUnexpectedEOF {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
frame := &settingsFrame{settings: make(map[uint64]uint64)}
|
||||||
|
b := bytes.NewReader(buf)
|
||||||
|
for b.Len() > 0 {
|
||||||
|
id, err := utils.ReadVarInt(b)
|
||||||
|
if err != nil { // should not happen. We allocated the whole frame already.
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
val, err := utils.ReadVarInt(b)
|
||||||
|
if err != nil { // should not happen. We allocated the whole frame already.
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, ok := frame.settings[id]; ok {
|
||||||
|
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||||
|
}
|
||||||
|
frame.settings[id] = val
|
||||||
|
}
|
||||||
|
return frame, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *settingsFrame) Write(b *bytes.Buffer) {
|
||||||
|
utils.WriteVarInt(b, 0x4)
|
||||||
|
var l protocol.ByteCount
|
||||||
|
for id, val := range f.settings {
|
||||||
|
l += utils.VarIntLen(id) + utils.VarIntLen(val)
|
||||||
|
}
|
||||||
|
utils.WriteVarInt(b, uint64(l))
|
||||||
|
for id, val := range f.settings {
|
||||||
|
utils.WriteVarInt(b, id)
|
||||||
|
utils.WriteVarInt(b, val)
|
||||||
|
}
|
||||||
|
}
|
136
http3/frames_test.go
Normal file
136
http3/frames_test.go
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Frames", func() {
|
||||||
|
appendVarInt := func(b []byte, val uint64) []byte {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
utils.WriteVarInt(buf, val)
|
||||||
|
return append(b, buf.Bytes()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
It("skips unknown frames", func() {
|
||||||
|
data := appendVarInt(nil, 0xdeadbeef) // type byte
|
||||||
|
data = appendVarInt(data, 0x42)
|
||||||
|
data = append(data, make([]byte, 0x42)...)
|
||||||
|
buf := bytes.NewBuffer(data)
|
||||||
|
(&dataFrame{Length: 0x1234}).Write(buf)
|
||||||
|
frame, err := parseNextFrame(buf)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||||
|
Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234)))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("DATA frames", func() {
|
||||||
|
It("parses", func() {
|
||||||
|
data := appendVarInt(nil, 0) // type byte
|
||||||
|
data = appendVarInt(data, 0x1337)
|
||||||
|
frame, err := parseNextFrame(bytes.NewReader(data))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||||
|
Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes", func() {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
(&dataFrame{Length: 0xdeadbeef}).Write(buf)
|
||||||
|
frame, err := parseNextFrame(buf)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||||
|
Expect(frame.(*dataFrame).Length).To(Equal(uint64(0xdeadbeef)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("HEADERS frames", func() {
|
||||||
|
It("parses", func() {
|
||||||
|
data := appendVarInt(nil, 1) // type byte
|
||||||
|
data = appendVarInt(data, 0x1337)
|
||||||
|
frame, err := parseNextFrame(bytes.NewReader(data))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||||
|
Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes", func() {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
(&headersFrame{Length: 0xdeadbeef}).Write(buf)
|
||||||
|
frame, err := parseNextFrame(buf)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||||
|
Expect(frame.(*headersFrame).Length).To(Equal(uint64(0xdeadbeef)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("SETTINGS frames", func() {
|
||||||
|
It("parses", func() {
|
||||||
|
settings := appendVarInt(nil, 13)
|
||||||
|
settings = appendVarInt(settings, 37)
|
||||||
|
settings = appendVarInt(settings, 0xdead)
|
||||||
|
settings = appendVarInt(settings, 0xbeef)
|
||||||
|
data := appendVarInt(nil, 4) // type byte
|
||||||
|
data = appendVarInt(data, uint64(len(settings)))
|
||||||
|
data = append(data, settings...)
|
||||||
|
frame, err := parseNextFrame(bytes.NewReader(data))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{}))
|
||||||
|
sf := frame.(*settingsFrame)
|
||||||
|
Expect(sf.settings).To(HaveKeyWithValue(uint64(13), uint64(37)))
|
||||||
|
Expect(sf.settings).To(HaveKeyWithValue(uint64(0xdead), uint64(0xbeef)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects duplicate settings", func() {
|
||||||
|
settings := appendVarInt(nil, 13)
|
||||||
|
settings = appendVarInt(settings, 37)
|
||||||
|
settings = appendVarInt(settings, 13)
|
||||||
|
settings = appendVarInt(settings, 38)
|
||||||
|
data := appendVarInt(nil, 4) // type byte
|
||||||
|
data = appendVarInt(data, uint64(len(settings)))
|
||||||
|
data = append(data, settings...)
|
||||||
|
_, err := parseNextFrame(bytes.NewReader(data))
|
||||||
|
Expect(err).To(MatchError("duplicate setting: 13"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes", func() {
|
||||||
|
sf := &settingsFrame{settings: map[uint64]uint64{
|
||||||
|
1: 2,
|
||||||
|
99: 999,
|
||||||
|
13: 37,
|
||||||
|
}}
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
sf.Write(buf)
|
||||||
|
frame, err := parseNextFrame(buf)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(Equal(sf))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors on EOF", func() {
|
||||||
|
sf := &settingsFrame{settings: map[uint64]uint64{
|
||||||
|
13: 37,
|
||||||
|
0xdeadbeef: 0xdecafbad,
|
||||||
|
}}
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
sf.Write(buf)
|
||||||
|
|
||||||
|
data := buf.Bytes()
|
||||||
|
_, err := parseNextFrame(bytes.NewReader(data))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
for i := range data {
|
||||||
|
b := make([]byte, i)
|
||||||
|
copy(b, data[:i])
|
||||||
|
_, err := parseNextFrame(bytes.NewReader(b))
|
||||||
|
Expect(err).To(MatchError(io.EOF))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
|
@ -3,6 +3,8 @@ package http3
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
@ -11,3 +13,13 @@ func TestHttp3(t *testing.T) {
|
||||||
RegisterFailHandler(Fail)
|
RegisterFailHandler(Fail)
|
||||||
RunSpecs(t, "HTTP/3 Suite")
|
RunSpecs(t, "HTTP/3 Suite")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var mockCtrl *gomock.Controller
|
||||||
|
|
||||||
|
var _ = BeforeEach(func() {
|
||||||
|
mockCtrl = gomock.NewController(GinkgoT())
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = AfterEach(func() {
|
||||||
|
mockCtrl.Finish()
|
||||||
|
})
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package h2quic
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -8,10 +8,10 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/net/http2/hpack"
|
"github.com/marten-seemann/qpack"
|
||||||
)
|
)
|
||||||
|
|
||||||
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
|
func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) {
|
||||||
var path, authority, method, contentLengthStr string
|
var path, authority, method, contentLengthStr string
|
||||||
httpHeaders := http.Header{}
|
httpHeaders := http.Header{}
|
||||||
|
|
||||||
|
@ -57,8 +57,8 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
|
||||||
return &http.Request{
|
return &http.Request{
|
||||||
Method: method,
|
Method: method,
|
||||||
URL: u,
|
URL: u,
|
||||||
Proto: "HTTP/2.0",
|
Proto: "HTTP/3",
|
||||||
ProtoMajor: 2,
|
ProtoMajor: 3,
|
||||||
ProtoMinor: 0,
|
ProtoMinor: 0,
|
||||||
Header: httpHeaders,
|
Header: httpHeaders,
|
||||||
Body: nil,
|
Body: nil,
|
312
http3/request_writer.go
Normal file
312
http3/request_writer.go
Normal file
|
@ -0,0 +1,312 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/marten-seemann/qpack"
|
||||||
|
"golang.org/x/net/http/httpguts"
|
||||||
|
"golang.org/x/net/http2/hpack"
|
||||||
|
"golang.org/x/net/idna"
|
||||||
|
)
|
||||||
|
|
||||||
|
type requestWriter struct {
|
||||||
|
mutex sync.Mutex
|
||||||
|
encoder *qpack.Encoder
|
||||||
|
headerBuf *bytes.Buffer
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRequestWriter(logger utils.Logger) *requestWriter {
|
||||||
|
headerBuf := &bytes.Buffer{}
|
||||||
|
encoder := qpack.NewEncoder(headerBuf)
|
||||||
|
return &requestWriter{
|
||||||
|
encoder: encoder,
|
||||||
|
headerBuf: headerBuf,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request) error {
|
||||||
|
headers, err := w.getHeaders(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := str.Write(headers); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// TODO: add support for trailers
|
||||||
|
if req.Body == nil {
|
||||||
|
str.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// send the request body asynchronously
|
||||||
|
go func() {
|
||||||
|
if err := w.sendRequestBody(req.Body, str); err != nil {
|
||||||
|
w.logger.Errorf("Error writing request: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
str.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *requestWriter) getHeaders(req *http.Request) ([]byte, error) {
|
||||||
|
w.mutex.Lock()
|
||||||
|
defer w.mutex.Unlock()
|
||||||
|
defer w.encoder.Close()
|
||||||
|
|
||||||
|
if err := w.encodeHeaders(req, false, "", actualContentLength(req)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
hf := headersFrame{Length: uint64(w.headerBuf.Len())}
|
||||||
|
hf.Write(buf)
|
||||||
|
if _, err := io.Copy(buf, w.headerBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
w.headerBuf.Reset()
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *requestWriter) sendRequestBody(req io.ReadCloser, str quic.Stream) error {
|
||||||
|
b := make([]byte, 8*1024)
|
||||||
|
for {
|
||||||
|
n, err := req.Read(b)
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
(&dataFrame{Length: uint64(n)}).Write(buf)
|
||||||
|
if _, err := str.Write(buf.Bytes()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := str.Write(b[:n]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// copied from net/transport.go
|
||||||
|
|
||||||
|
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error {
|
||||||
|
host := req.Host
|
||||||
|
if host == "" {
|
||||||
|
host = req.URL.Host
|
||||||
|
}
|
||||||
|
host, err := httpguts.PunycodeHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var path string
|
||||||
|
if req.Method != "CONNECT" {
|
||||||
|
path = req.URL.RequestURI()
|
||||||
|
if !validPseudoPath(path) {
|
||||||
|
orig := path
|
||||||
|
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
||||||
|
if !validPseudoPath(path) {
|
||||||
|
if req.URL.Opaque != "" {
|
||||||
|
return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("invalid request :path %q", orig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for any invalid headers and return an error before we
|
||||||
|
// potentially pollute our hpack state. (We want to be able to
|
||||||
|
// continue to reuse the hpack encoder for future requests)
|
||||||
|
for k, vv := range req.Header {
|
||||||
|
if !httpguts.ValidHeaderFieldName(k) {
|
||||||
|
return fmt.Errorf("invalid HTTP header name %q", k)
|
||||||
|
}
|
||||||
|
for _, v := range vv {
|
||||||
|
if !httpguts.ValidHeaderFieldValue(v) {
|
||||||
|
return fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enumerateHeaders := func(f func(name, value string)) {
|
||||||
|
// 8.1.2.3 Request Pseudo-Header Fields
|
||||||
|
// The :path pseudo-header field includes the path and query parts of the
|
||||||
|
// target URI (the path-absolute production and optionally a '?' character
|
||||||
|
// followed by the query production (see Sections 3.3 and 3.4 of
|
||||||
|
// [RFC3986]).
|
||||||
|
f(":authority", host)
|
||||||
|
f(":method", req.Method)
|
||||||
|
if req.Method != "CONNECT" {
|
||||||
|
f(":path", path)
|
||||||
|
f(":scheme", req.URL.Scheme)
|
||||||
|
}
|
||||||
|
if trailers != "" {
|
||||||
|
f("trailer", trailers)
|
||||||
|
}
|
||||||
|
|
||||||
|
var didUA bool
|
||||||
|
for k, vv := range req.Header {
|
||||||
|
if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
|
||||||
|
// Host is :authority, already sent.
|
||||||
|
// Content-Length is automatic, set below.
|
||||||
|
continue
|
||||||
|
} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
|
||||||
|
strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
|
||||||
|
strings.EqualFold(k, "keep-alive") {
|
||||||
|
// Per 8.1.2.2 Connection-Specific Header
|
||||||
|
// Fields, don't send connection-specific
|
||||||
|
// fields. We have already checked if any
|
||||||
|
// are error-worthy so just ignore the rest.
|
||||||
|
continue
|
||||||
|
} else if strings.EqualFold(k, "user-agent") {
|
||||||
|
// Match Go's http1 behavior: at most one
|
||||||
|
// User-Agent. If set to nil or empty string,
|
||||||
|
// then omit it. Otherwise if not mentioned,
|
||||||
|
// include the default (below).
|
||||||
|
didUA = true
|
||||||
|
if len(vv) < 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
vv = vv[:1]
|
||||||
|
if vv[0] == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range vv {
|
||||||
|
f(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if shouldSendReqContentLength(req.Method, contentLength) {
|
||||||
|
f("content-length", strconv.FormatInt(contentLength, 10))
|
||||||
|
}
|
||||||
|
if addGzipHeader {
|
||||||
|
f("accept-encoding", "gzip")
|
||||||
|
}
|
||||||
|
if !didUA {
|
||||||
|
f("user-agent", defaultUserAgent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do a first pass over the headers counting bytes to ensure
|
||||||
|
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
|
||||||
|
// separate pass before encoding the headers to prevent
|
||||||
|
// modifying the hpack state.
|
||||||
|
hlSize := uint64(0)
|
||||||
|
enumerateHeaders(func(name, value string) {
|
||||||
|
hf := hpack.HeaderField{Name: name, Value: value}
|
||||||
|
hlSize += uint64(hf.Size())
|
||||||
|
})
|
||||||
|
|
||||||
|
// TODO: check maximum header list size
|
||||||
|
// if hlSize > cc.peerMaxHeaderListSize {
|
||||||
|
// return errRequestHeaderListSize
|
||||||
|
// }
|
||||||
|
|
||||||
|
// trace := httptrace.ContextClientTrace(req.Context())
|
||||||
|
// traceHeaders := traceHasWroteHeaderField(trace)
|
||||||
|
|
||||||
|
// Header list size is ok. Write the headers.
|
||||||
|
enumerateHeaders(func(name, value string) {
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
|
||||||
|
// if traceHeaders {
|
||||||
|
// traceWroteHeaderField(trace, name, value)
|
||||||
|
// }
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||||
|
// and returns a host:port. The port 443 is added if needed.
|
||||||
|
func authorityAddr(scheme string, authority string) (addr string) {
|
||||||
|
host, port, err := net.SplitHostPort(authority)
|
||||||
|
if err != nil { // authority didn't have a port
|
||||||
|
port = "443"
|
||||||
|
if scheme == "http" {
|
||||||
|
port = "80"
|
||||||
|
}
|
||||||
|
host = authority
|
||||||
|
}
|
||||||
|
if a, err := idna.ToASCII(host); err == nil {
|
||||||
|
host = a
|
||||||
|
}
|
||||||
|
// IPv6 address literal, without a port:
|
||||||
|
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||||
|
return host + ":" + port
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validPseudoPath reports whether v is a valid :path pseudo-header
|
||||||
|
// value. It must be either:
|
||||||
|
//
|
||||||
|
// *) a non-empty string starting with '/'
|
||||||
|
// *) the string '*', for OPTIONS requests.
|
||||||
|
//
|
||||||
|
// For now this is only used a quick check for deciding when to clean
|
||||||
|
// up Opaque URLs before sending requests from the Transport.
|
||||||
|
// See golang.org/issue/16847
|
||||||
|
//
|
||||||
|
// We used to enforce that the path also didn't start with "//", but
|
||||||
|
// Google's GFE accepts such paths and Chrome sends them, so ignore
|
||||||
|
// that part of the spec. See golang.org/issue/19103.
|
||||||
|
func validPseudoPath(v string) bool {
|
||||||
|
return (len(v) > 0 && v[0] == '/') || v == "*"
|
||||||
|
}
|
||||||
|
|
||||||
|
// actualContentLength returns a sanitized version of
|
||||||
|
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
||||||
|
// means unknown.
|
||||||
|
func actualContentLength(req *http.Request) int64 {
|
||||||
|
if req.Body == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if req.ContentLength != 0 {
|
||||||
|
return req.ContentLength
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldSendReqContentLength reports whether the http2.Transport should send
|
||||||
|
// a "content-length" request header. This logic is basically a copy of the net/http
|
||||||
|
// transferWriter.shouldSendContentLength.
|
||||||
|
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
||||||
|
// -1 means unknown.
|
||||||
|
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
||||||
|
if contentLength > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if contentLength < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// For zero bodies, whether we send a content-length depends on the method.
|
||||||
|
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
||||||
|
switch method {
|
||||||
|
case "POST", "PUT", "PATCH":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
105
http3/request_writer_test.go
Normal file
105
http3/request_writer_test.go
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/marten-seemann/qpack"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Request Writer", func() {
|
||||||
|
var (
|
||||||
|
rw *requestWriter
|
||||||
|
str *mockquic.MockStream
|
||||||
|
strBuf *bytes.Buffer
|
||||||
|
)
|
||||||
|
|
||||||
|
decode := func(str io.Reader) map[string]string {
|
||||||
|
frame, err := parseNextFrame(str)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||||
|
headersFrame := frame.(*headersFrame)
|
||||||
|
data := make([]byte, headersFrame.Length)
|
||||||
|
_, err = io.ReadFull(str, data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
decoder := qpack.NewDecoder(nil)
|
||||||
|
hfs, err := decoder.DecodeFull(data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
values := make(map[string]string)
|
||||||
|
for _, hf := range hfs {
|
||||||
|
values[hf.Name] = hf.Value
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
rw = newRequestWriter(utils.DefaultLogger)
|
||||||
|
strBuf = &bytes.Buffer{}
|
||||||
|
str = mockquic.NewMockStream(mockCtrl)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return strBuf.Write(p)
|
||||||
|
}).AnyTimes()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes a GET request", func() {
|
||||||
|
str.EXPECT().Close()
|
||||||
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(rw.WriteRequest(str, req)).To(Succeed())
|
||||||
|
headerFields := decode(strBuf)
|
||||||
|
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||||
|
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
||||||
|
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
|
||||||
|
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
|
||||||
|
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes a POST request", func() {
|
||||||
|
closed := make(chan struct{})
|
||||||
|
str.EXPECT().Close().Do(func() { close(closed) })
|
||||||
|
postData := bytes.NewReader([]byte("foobar"))
|
||||||
|
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(rw.WriteRequest(str, req)).To(Succeed())
|
||||||
|
headerFields := decode(strBuf)
|
||||||
|
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
||||||
|
Expect(headerFields).To(HaveKey("content-length"))
|
||||||
|
contentLength, err := strconv.Atoi(headerFields["content-length"])
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(contentLength).To(BeNumerically(">", 0))
|
||||||
|
|
||||||
|
Eventually(closed).Should(BeClosed())
|
||||||
|
frame, err := parseNextFrame(strBuf)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||||
|
Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sends cookies", func() {
|
||||||
|
str.EXPECT().Close()
|
||||||
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
cookie1 := &http.Cookie{
|
||||||
|
Name: "Cookie #1",
|
||||||
|
Value: "Value #1",
|
||||||
|
}
|
||||||
|
cookie2 := &http.Cookie{
|
||||||
|
Name: "Cookie #2",
|
||||||
|
Value: "Value #2",
|
||||||
|
}
|
||||||
|
req.AddCookie(cookie1)
|
||||||
|
req.AddCookie(cookie2)
|
||||||
|
Expect(rw.WriteRequest(str, req)).To(Succeed())
|
||||||
|
headerFields := decode(strBuf)
|
||||||
|
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
|
||||||
|
})
|
||||||
|
})
|
|
@ -1,4 +1,4 @@
|
||||||
package h2quic
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
|
@ -1,7 +1,8 @@
|
||||||
package h2quic
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"github.com/golang/mock/gomock"
|
||||||
|
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
@ -9,21 +10,17 @@ import (
|
||||||
|
|
||||||
var _ = Describe("Response Body", func() {
|
var _ = Describe("Response Body", func() {
|
||||||
var (
|
var (
|
||||||
stream *mockStream
|
stream *mockquic.MockStream
|
||||||
body *responseBody
|
body *responseBody
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
stream = newMockStream(42)
|
stream = mockquic.NewMockStream(mockCtrl)
|
||||||
body = &responseBody{stream}
|
body = &responseBody{stream}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("calls CancelRead when closing", func() {
|
It("calls CancelRead when closing", func() {
|
||||||
stream.dataToRead = *bytes.NewBuffer([]byte("foobar"))
|
stream.EXPECT().CancelRead(gomock.Any())
|
||||||
n, err := body.Read(make([]byte, 3))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
Expect(body.Close()).To(Succeed())
|
Expect(body.Close()).To(Succeed())
|
||||||
Expect(stream.canceledRead).To(BeTrue())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
|
@ -1,25 +1,18 @@
|
||||||
package h2quic
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"golang.org/x/net/http2"
|
"github.com/marten-seemann/qpack"
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type responseWriter struct {
|
type responseWriter struct {
|
||||||
dataStreamID protocol.StreamID
|
stream io.Writer
|
||||||
dataStream quic.Stream
|
|
||||||
|
|
||||||
headerStream quic.Stream
|
|
||||||
headerStreamMutex *sync.Mutex
|
|
||||||
|
|
||||||
header http.Header
|
header http.Header
|
||||||
status int // status code passed to WriteHeader
|
status int // status code passed to WriteHeader
|
||||||
|
@ -28,20 +21,13 @@ type responseWriter struct {
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResponseWriter(
|
var _ http.ResponseWriter = &responseWriter{}
|
||||||
headerStream quic.Stream,
|
|
||||||
headerStreamMutex *sync.Mutex,
|
func newResponseWriter(stream io.Writer, logger utils.Logger) *responseWriter {
|
||||||
dataStream quic.Stream,
|
|
||||||
dataStreamID protocol.StreamID,
|
|
||||||
logger utils.Logger,
|
|
||||||
) *responseWriter {
|
|
||||||
return &responseWriter{
|
return &responseWriter{
|
||||||
header: http.Header{},
|
header: http.Header{},
|
||||||
headerStream: headerStream,
|
stream: stream,
|
||||||
headerStreamMutex: headerStreamMutex,
|
logger: logger,
|
||||||
dataStream: dataStream,
|
|
||||||
dataStreamID: dataStreamID,
|
|
||||||
logger: logger,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,26 +43,23 @@ func (w *responseWriter) WriteHeader(status int) {
|
||||||
w.status = status
|
w.status = status
|
||||||
|
|
||||||
var headers bytes.Buffer
|
var headers bytes.Buffer
|
||||||
enc := hpack.NewEncoder(&headers)
|
enc := qpack.NewEncoder(&headers)
|
||||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
||||||
|
|
||||||
for k, v := range w.header {
|
for k, v := range w.header {
|
||||||
for index := range v {
|
for index := range v {
|
||||||
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
|
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
(&headersFrame{Length: uint64(headers.Len())}).Write(buf)
|
||||||
w.logger.Infof("Responding with %d", status)
|
w.logger.Infof("Responding with %d", status)
|
||||||
w.headerStreamMutex.Lock()
|
if _, err := w.stream.Write(buf.Bytes()); err != nil {
|
||||||
defer w.headerStreamMutex.Unlock()
|
w.logger.Errorf("could not write headers frame: %s", err.Error())
|
||||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
}
|
||||||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
|
if _, err := w.stream.Write(headers.Bytes()); err != nil {
|
||||||
StreamID: uint32(w.dataStreamID),
|
w.logger.Errorf("could not write header frame payload: %s", err.Error())
|
||||||
EndHeaders: true,
|
|
||||||
BlockFragment: headers.Bytes(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
w.logger.Errorf("could not write h2 header: %s", err.Error())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +70,13 @@ func (w *responseWriter) Write(p []byte) (int, error) {
|
||||||
if !bodyAllowedForStatus(w.status) {
|
if !bodyAllowedForStatus(w.status) {
|
||||||
return 0, http.ErrBodyNotAllowed
|
return 0, http.ErrBodyNotAllowed
|
||||||
}
|
}
|
||||||
return w.dataStream.Write(p)
|
df := &dataFrame{Length: uint64(len(p))}
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
df.Write(buf)
|
||||||
|
if _, err := w.stream.Write(buf.Bytes()); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return w.stream.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *responseWriter) Flush() {}
|
func (w *responseWriter) Flush() {}
|
|
@ -1,4 +1,4 @@
|
||||||
package h2quic
|
package http3
|
||||||
|
|
||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
120
http3/response_writer_test.go
Normal file
120
http3/response_writer_test.go
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/marten-seemann/qpack"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Response Writer", func() {
|
||||||
|
var (
|
||||||
|
rw *responseWriter
|
||||||
|
strBuf *bytes.Buffer
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
strBuf = &bytes.Buffer{}
|
||||||
|
rw = newResponseWriter(strBuf, utils.DefaultLogger)
|
||||||
|
})
|
||||||
|
|
||||||
|
decodeHeader := func(str io.Reader) map[string][]string {
|
||||||
|
fields := make(map[string][]string)
|
||||||
|
decoder := qpack.NewDecoder(nil)
|
||||||
|
|
||||||
|
frame, err := parseNextFrame(str)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||||
|
headersFrame := frame.(*headersFrame)
|
||||||
|
data := make([]byte, headersFrame.Length)
|
||||||
|
_, err = io.ReadFull(str, data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
hfs, err := decoder.DecodeFull(data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
for _, p := range hfs {
|
||||||
|
fields[p.Name] = append(fields[p.Name], p.Value)
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
getData := func(str io.Reader) []byte {
|
||||||
|
frame, err := parseNextFrame(str)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&dataFrame{}))
|
||||||
|
df := frame.(*dataFrame)
|
||||||
|
data := make([]byte, df.Length)
|
||||||
|
_, err = io.ReadFull(str, data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
It("writes status", func() {
|
||||||
|
rw.WriteHeader(http.StatusTeapot)
|
||||||
|
fields := decodeHeader(strBuf)
|
||||||
|
Expect(fields).To(HaveLen(1))
|
||||||
|
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes headers", func() {
|
||||||
|
rw.Header().Add("content-length", "42")
|
||||||
|
rw.WriteHeader(http.StatusTeapot)
|
||||||
|
fields := decodeHeader(strBuf)
|
||||||
|
Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes multiple headers with the same name", func() {
|
||||||
|
const cookie1 = "test1=1; Max-Age=7200; path=/"
|
||||||
|
const cookie2 = "test2=2; Max-Age=7200; path=/"
|
||||||
|
rw.Header().Add("set-cookie", cookie1)
|
||||||
|
rw.Header().Add("set-cookie", cookie2)
|
||||||
|
rw.WriteHeader(http.StatusTeapot)
|
||||||
|
fields := decodeHeader(strBuf)
|
||||||
|
Expect(fields).To(HaveKey("set-cookie"))
|
||||||
|
cookies := fields["set-cookie"]
|
||||||
|
Expect(cookies).To(ContainElement(cookie1))
|
||||||
|
Expect(cookies).To(ContainElement(cookie2))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes data", func() {
|
||||||
|
n, err := rw.Write([]byte("foobar"))
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
// Should have written 200 on the header stream
|
||||||
|
fields := decodeHeader(strBuf)
|
||||||
|
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||||
|
// And foobar on the data stream
|
||||||
|
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes data after WriteHeader is called", func() {
|
||||||
|
rw.WriteHeader(http.StatusTeapot)
|
||||||
|
n, err := rw.Write([]byte("foobar"))
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
// Should have written 418 on the header stream
|
||||||
|
fields := decodeHeader(strBuf)
|
||||||
|
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
||||||
|
// And foobar on the data stream
|
||||||
|
Expect(getData(strBuf)).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("does not WriteHeader() twice", func() {
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
rw.WriteHeader(500)
|
||||||
|
fields := decodeHeader(strBuf)
|
||||||
|
Expect(fields).To(HaveLen(1))
|
||||||
|
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't allow writes if the status code doesn't allow a body", func() {
|
||||||
|
rw.WriteHeader(304)
|
||||||
|
n, err := rw.Write([]byte("foobar"))
|
||||||
|
Expect(n).To(BeZero())
|
||||||
|
Expect(err).To(MatchError(http.ErrBodyNotAllowed))
|
||||||
|
})
|
||||||
|
})
|
|
@ -1,4 +1,4 @@
|
||||||
package h2quic
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -61,42 +61,42 @@ type RoundTripOpt struct {
|
||||||
var _ roundTripCloser = &RoundTripper{}
|
var _ roundTripCloser = &RoundTripper{}
|
||||||
|
|
||||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
||||||
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
|
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||||
|
|
||||||
// RoundTripOpt is like RoundTrip, but takes options.
|
// RoundTripOpt is like RoundTrip, but takes options.
|
||||||
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||||
if req.URL == nil {
|
if req.URL == nil {
|
||||||
closeRequestBody(req)
|
closeRequestBody(req)
|
||||||
return nil, errors.New("quic: nil Request.URL")
|
return nil, errors.New("http3: nil Request.URL")
|
||||||
}
|
}
|
||||||
if req.URL.Host == "" {
|
if req.URL.Host == "" {
|
||||||
closeRequestBody(req)
|
closeRequestBody(req)
|
||||||
return nil, errors.New("quic: no Host in request URL")
|
return nil, errors.New("http3: no Host in request URL")
|
||||||
}
|
}
|
||||||
if req.Header == nil {
|
if req.Header == nil {
|
||||||
closeRequestBody(req)
|
closeRequestBody(req)
|
||||||
return nil, errors.New("quic: nil Request.Header")
|
return nil, errors.New("http3: nil Request.Header")
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.URL.Scheme == "https" {
|
if req.URL.Scheme == "https" {
|
||||||
for k, vv := range req.Header {
|
for k, vv := range req.Header {
|
||||||
if !httpguts.ValidHeaderFieldName(k) {
|
if !httpguts.ValidHeaderFieldName(k) {
|
||||||
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
|
return nil, fmt.Errorf("http3: invalid http header field name %q", k)
|
||||||
}
|
}
|
||||||
for _, v := range vv {
|
for _, v := range vv {
|
||||||
if !httpguts.ValidHeaderFieldValue(v) {
|
if !httpguts.ValidHeaderFieldValue(v) {
|
||||||
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
|
return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
closeRequestBody(req)
|
closeRequestBody(req)
|
||||||
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
|
return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Method != "" && !validMethod(req.Method) {
|
if req.Method != "" && !validMethod(req.Method) {
|
||||||
closeRequestBody(req)
|
closeRequestBody(req)
|
||||||
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
|
return nil, fmt.Errorf("http3: invalid method %q", req.Method)
|
||||||
}
|
}
|
||||||
|
|
||||||
hostname := authorityAddr("https", hostnameFromRequest(req))
|
hostname := authorityAddr("https", hostnameFromRequest(req))
|
|
@ -1,4 +1,4 @@
|
||||||
package h2quic
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -8,7 +8,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
@ -34,6 +36,9 @@ type mockBody struct {
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make sure the mockBody can be used as a http.Request.Body
|
||||||
|
var _ io.ReadCloser = &mockBody{}
|
||||||
|
|
||||||
func (m *mockBody) Read(p []byte) (int, error) {
|
func (m *mockBody) Read(p []byte) (int, error) {
|
||||||
if m.readErr != nil {
|
if m.readErr != nil {
|
||||||
return 0, m.readErr
|
return 0, m.readErr
|
||||||
|
@ -50,13 +55,11 @@ func (m *mockBody) Close() error {
|
||||||
return m.closeErr
|
return m.closeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// make sure the mockBody can be used as a http.Request.Body
|
|
||||||
var _ io.ReadCloser = &mockBody{}
|
|
||||||
|
|
||||||
var _ = Describe("RoundTripper", func() {
|
var _ = Describe("RoundTripper", func() {
|
||||||
var (
|
var (
|
||||||
rt *RoundTripper
|
rt *RoundTripper
|
||||||
req1 *http.Request
|
req1 *http.Request
|
||||||
|
session *mockquic.MockSession
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -68,14 +71,14 @@ var _ = Describe("RoundTripper", func() {
|
||||||
|
|
||||||
Context("dialing hosts", func() {
|
Context("dialing hosts", func() {
|
||||||
origDialAddr := dialAddr
|
origDialAddr := dialAddr
|
||||||
streamOpenErr := errors.New("error opening stream")
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
session = mockquic.NewMockSession(mockCtrl)
|
||||||
origDialAddr = dialAddr
|
origDialAddr = dialAddr
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
||||||
// return an error when trying to open a stream
|
// return an error when trying to open a stream
|
||||||
// we don't want to test all the dial logic here, just that dialing happens at all
|
// we don't want to test all the dial logic here, just that dialing happens at all
|
||||||
return &mockSession{streamOpenErr: streamOpenErr}, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -84,11 +87,17 @@ var _ = Describe("RoundTripper", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates new clients", func() {
|
It("creates new clients", func() {
|
||||||
|
closed := make(chan struct{})
|
||||||
|
testErr := errors.New("test err")
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
|
||||||
|
session.EXPECT().OpenStreamSync().Return(nil, testErr)
|
||||||
|
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, error) { close(closed) })
|
||||||
_, err = rt.RoundTrip(req)
|
_, err = rt.RoundTrip(req)
|
||||||
Expect(err).To(MatchError(streamOpenErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
Expect(rt.clients).To(HaveLen(1))
|
Expect(rt.clients).To(HaveLen(1))
|
||||||
|
Eventually(closed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("uses the quic.Config, if provided", func() {
|
It("uses the quic.Config, if provided", func() {
|
||||||
|
@ -96,35 +105,43 @@ var _ = Describe("RoundTripper", func() {
|
||||||
var receivedConfig *quic.Config
|
var receivedConfig *quic.Config
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
||||||
receivedConfig = config
|
receivedConfig = config
|
||||||
return nil, errors.New("err")
|
return nil, errors.New("handshake error")
|
||||||
}
|
}
|
||||||
rt.QuicConfig = config
|
rt.QuicConfig = config
|
||||||
rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
Expect(receivedConfig).To(Equal(config))
|
Expect(err).To(MatchError("handshake error"))
|
||||||
|
Expect(receivedConfig.HandshakeTimeout).To(Equal(config.HandshakeTimeout))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("uses the custom dialer, if provided", func() {
|
It("uses the custom dialer, if provided", func() {
|
||||||
var dialed bool
|
var dialed bool
|
||||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
||||||
dialed = true
|
dialed = true
|
||||||
return nil, errors.New("err")
|
return nil, errors.New("handshake error")
|
||||||
}
|
}
|
||||||
rt.Dial = dialer
|
rt.Dial = dialer
|
||||||
rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
|
Expect(err).To(MatchError("handshake error"))
|
||||||
Expect(dialed).To(BeTrue())
|
Expect(dialed).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("reuses existing clients", func() {
|
It("reuses existing clients", func() {
|
||||||
|
closed := make(chan struct{})
|
||||||
|
testErr := errors.New("test err")
|
||||||
|
session.EXPECT().OpenUniStreamSync().AnyTimes().Return(nil, testErr)
|
||||||
|
session.EXPECT().OpenStreamSync().Return(nil, testErr).Times(2)
|
||||||
|
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, error) { close(closed) })
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = rt.RoundTrip(req)
|
_, err = rt.RoundTrip(req)
|
||||||
Expect(err).To(MatchError(streamOpenErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
Expect(rt.clients).To(HaveLen(1))
|
Expect(rt.clients).To(HaveLen(1))
|
||||||
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = rt.RoundTrip(req2)
|
_, err = rt.RoundTrip(req2)
|
||||||
Expect(err).To(MatchError(streamOpenErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
Expect(rt.clients).To(HaveLen(1))
|
Expect(rt.clients).To(HaveLen(1))
|
||||||
|
Eventually(closed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
|
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
|
||||||
|
@ -141,7 +158,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
req.Body = &mockBody{}
|
req.Body = &mockBody{}
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = rt.RoundTrip(req)
|
_, err = rt.RoundTrip(req)
|
||||||
Expect(err).To(MatchError("quic: unsupported protocol scheme: http"))
|
Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
|
||||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -149,7 +166,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
req1.URL = nil
|
req1.URL = nil
|
||||||
req1.Body = &mockBody{}
|
req1.Body = &mockBody{}
|
||||||
_, err := rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
Expect(err).To(MatchError("quic: nil Request.URL"))
|
Expect(err).To(MatchError("http3: nil Request.URL"))
|
||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -157,7 +174,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
req1.URL.Host = ""
|
req1.URL.Host = ""
|
||||||
req1.Body = &mockBody{}
|
req1.Body = &mockBody{}
|
||||||
_, err := rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
Expect(err).To(MatchError("quic: no Host in request URL"))
|
Expect(err).To(MatchError("http3: no Host in request URL"))
|
||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -165,34 +182,34 @@ var _ = Describe("RoundTripper", func() {
|
||||||
req1.URL = nil
|
req1.URL = nil
|
||||||
Expect(req1.Body).To(BeNil())
|
Expect(req1.Body).To(BeNil())
|
||||||
_, err := rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
Expect(err).To(MatchError("quic: nil Request.URL"))
|
Expect(err).To(MatchError("http3: nil Request.URL"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects requests without a header", func() {
|
It("rejects requests without a header", func() {
|
||||||
req1.Header = nil
|
req1.Header = nil
|
||||||
req1.Body = &mockBody{}
|
req1.Body = &mockBody{}
|
||||||
_, err := rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
Expect(err).To(MatchError("quic: nil Request.Header"))
|
Expect(err).To(MatchError("http3: nil Request.Header"))
|
||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects requests with invalid header name fields", func() {
|
It("rejects requests with invalid header name fields", func() {
|
||||||
req1.Header.Add("foobär", "value")
|
req1.Header.Add("foobär", "value")
|
||||||
_, err := rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
Expect(err).To(MatchError("quic: invalid http header field name \"foobär\""))
|
Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects requests with invalid header name values", func() {
|
It("rejects requests with invalid header name values", func() {
|
||||||
req1.Header.Add("foo", string([]byte{0x7}))
|
req1.Header.Add("foo", string([]byte{0x7}))
|
||||||
_, err := rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value"))
|
Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects requests with an invalid request method", func() {
|
It("rejects requests with an invalid request method", func() {
|
||||||
req1.Method = "foobär"
|
req1.Method = "foobär"
|
||||||
req1.Body = &mockBody{}
|
req1.Body = &mockBody{}
|
||||||
_, err := rt.RoundTrip(req1)
|
_, err := rt.RoundTrip(req1)
|
||||||
Expect(err).To(MatchError("quic: invalid method \"foobär\""))
|
Expect(err).To(MatchError("http3: invalid method \"foobär\""))
|
||||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||||
})
|
})
|
||||||
})
|
})
|
|
@ -1,9 +1,10 @@
|
||||||
package h2quic
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -12,23 +13,12 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"golang.org/x/net/http2"
|
"github.com/marten-seemann/qpack"
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type streamCreator interface {
|
|
||||||
quic.Session
|
|
||||||
GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type remoteCloser interface {
|
|
||||||
CloseRemote(protocol.ByteCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
// allows mocking of quic.Listen and quic.ListenAddr
|
// allows mocking of quic.Listen and quic.ListenAddr
|
||||||
var (
|
var (
|
||||||
quicListen = quic.Listen
|
quicListen = quic.Listen
|
||||||
|
@ -43,9 +33,6 @@ type Server struct {
|
||||||
// If nil, it uses reasonable default values.
|
// If nil, it uses reasonable default values.
|
||||||
QuicConfig *quic.Config
|
QuicConfig *quic.Config
|
||||||
|
|
||||||
// Private flag for demo, do not use
|
|
||||||
CloseAfterFirstRequest bool
|
|
||||||
|
|
||||||
port uint32 // used atomically
|
port uint32 // used atomically
|
||||||
|
|
||||||
listenerMutex sync.Mutex
|
listenerMutex sync.Mutex
|
||||||
|
@ -54,18 +41,18 @@ type Server struct {
|
||||||
|
|
||||||
supportedVersionsAsString string
|
supportedVersionsAsString string
|
||||||
|
|
||||||
logger utils.Logger // will be set by Server.serveImpl()
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
|
||||||
func (s *Server) ListenAndServe() error {
|
func (s *Server) ListenAndServe() error {
|
||||||
if s.Server == nil {
|
if s.Server == nil {
|
||||||
return errors.New("use of h2quic.Server without http.Server")
|
return errors.New("use of http3.Server without http.Server")
|
||||||
}
|
}
|
||||||
return s.serveImpl(s.TLSConfig, nil)
|
return s.serveImpl(s.TLSConfig, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
|
||||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||||
var err error
|
var err error
|
||||||
certs := make([]tls.Certificate, 1)
|
certs := make([]tls.Certificate, 1)
|
||||||
|
@ -88,7 +75,7 @@ func (s *Server) Serve(conn net.PacketConn) error {
|
||||||
|
|
||||||
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
||||||
if s.Server == nil {
|
if s.Server == nil {
|
||||||
return errors.New("use of h2quic.Server without http.Server")
|
return errors.New("use of http3.Server without http.Server")
|
||||||
}
|
}
|
||||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
s.logger = utils.DefaultLogger.WithPrefix("server")
|
||||||
s.listenerMutex.Lock()
|
s.listenerMutex.Lock()
|
||||||
|
@ -120,138 +107,104 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go s.handleHeaderStream(sess.(streamCreator))
|
go s.handleConn(sess)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleHeaderStream(session streamCreator) {
|
func (s *Server) handleConn(sess quic.Session) {
|
||||||
stream, err := session.AcceptStream()
|
// TODO: accept control streams
|
||||||
if err != nil {
|
decoder := qpack.NewDecoder(nil)
|
||||||
session.CloseWithError(quic.ErrorCode(qerr.InternalError), err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
|
||||||
h2framer := http2.NewFramer(nil, stream)
|
|
||||||
|
|
||||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
|
||||||
for {
|
for {
|
||||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
str, err := sess.AcceptStream()
|
||||||
// QuicErrors must originate from stream.Read() returning an error.
|
if err != nil {
|
||||||
// In this case, the session has already logged the error, so we don't
|
s.logger.Debugf("Accepting stream failed: %s", err)
|
||||||
// need to log it again.
|
|
||||||
errorCode := qerr.InternalError
|
|
||||||
if qerr, ok := err.(*qerr.QuicError); ok {
|
|
||||||
errorCode = qerr.ErrorCode
|
|
||||||
s.logger.Errorf("error handling h2 request: %s", err.Error())
|
|
||||||
}
|
|
||||||
session.CloseWithError(quic.ErrorCode(errorCode), err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// TODO: handle error
|
||||||
|
go func() {
|
||||||
|
if err := s.handleRequest(str, decoder); err != nil {
|
||||||
|
s.logger.Debugf("Handling request failed: %s", err)
|
||||||
|
str.CancelWrite(quic.ErrorCode(errorGeneralProtocolError))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
str.Close()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
// TODO: improve error handling.
|
||||||
h2frame, err := h2framer.ReadFrame()
|
// Most (but not all) of the errors occurring here are connection-level erros.
|
||||||
|
func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder) error {
|
||||||
|
frame, err := parseNextFrame(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return qerr.Error(qerr.InternalError, "cannot read frame")
|
str.CancelWrite(quic.ErrorCode(errorRequestCanceled))
|
||||||
}
|
|
||||||
var h2headersFrame *http2.HeadersFrame
|
|
||||||
switch f := h2frame.(type) {
|
|
||||||
case *http2.PriorityFrame:
|
|
||||||
// ignore PRIORITY frames
|
|
||||||
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
|
|
||||||
return nil
|
|
||||||
case *http2.HeadersFrame:
|
|
||||||
h2headersFrame = f
|
|
||||||
default:
|
|
||||||
return qerr.Error(qerr.ProtocolViolation, "expected a header frame")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !h2headersFrame.HeadersEnded() {
|
|
||||||
return errors.New("http2 header continuation not implemented")
|
|
||||||
}
|
|
||||||
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
hf, ok := frame.(*headersFrame)
|
||||||
req, err := requestFromHeaders(headers)
|
if !ok {
|
||||||
|
str.CancelWrite(quic.ErrorCode(errorUnexpectedFrame))
|
||||||
|
return errors.New("expected first frame to be a headers frame")
|
||||||
|
}
|
||||||
|
// TODO: check length
|
||||||
|
headerBlock := make([]byte, hf.Length)
|
||||||
|
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||||
|
str.CancelWrite(quic.ErrorCode(errorIncompleteRequest))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hfs, err := decoder.DecodeFull(headerBlock)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: use the right error code
|
||||||
|
str.CancelWrite(quic.ErrorCode(errorGeneralProtocolError))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req, err := requestFromHeaders(hfs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
req.Body = newRequestBody(str)
|
||||||
|
|
||||||
if s.logger.Debug() {
|
if s.logger.Debug() {
|
||||||
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
|
||||||
} else {
|
} else {
|
||||||
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
|
req = req.WithContext(str.Context())
|
||||||
if err != nil {
|
responseWriter := newResponseWriter(str, s.logger)
|
||||||
return err
|
handler := s.Handler
|
||||||
}
|
if handler == nil {
|
||||||
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
|
handler = http.DefaultServeMux
|
||||||
if dataStream == nil {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRequest should be as non-blocking as possible to minimize
|
var panicked, readEOF bool
|
||||||
// head-of-line blocking. Potentially blocking code is run in a separate
|
func() {
|
||||||
// goroutine, enabling handleRequest to return before the code is executed.
|
defer func() {
|
||||||
go func() {
|
if p := recover(); p != nil {
|
||||||
streamEnded := h2headersFrame.StreamEnded()
|
// Copied from net/http/server.go
|
||||||
if streamEnded {
|
const size = 64 << 10
|
||||||
dataStream.(remoteCloser).CloseRemote(0)
|
buf := make([]byte, size)
|
||||||
streamEnded = true
|
buf = buf[:runtime.Stack(buf, false)]
|
||||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
|
||||||
}
|
panicked = true
|
||||||
|
|
||||||
req = req.WithContext(dataStream.Context())
|
|
||||||
reqBody := newRequestBody(dataStream)
|
|
||||||
req.Body = reqBody
|
|
||||||
|
|
||||||
req.RemoteAddr = session.RemoteAddr().String()
|
|
||||||
|
|
||||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
|
|
||||||
|
|
||||||
handler := s.Handler
|
|
||||||
if handler == nil {
|
|
||||||
handler = http.DefaultServeMux
|
|
||||||
}
|
|
||||||
panicked := false
|
|
||||||
func() {
|
|
||||||
defer func() {
|
|
||||||
if p := recover(); p != nil {
|
|
||||||
// Copied from net/http/server.go
|
|
||||||
const size = 64 << 10
|
|
||||||
buf := make([]byte, size)
|
|
||||||
buf = buf[:runtime.Stack(buf, false)]
|
|
||||||
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
|
|
||||||
panicked = true
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
handler.ServeHTTP(responseWriter, req)
|
|
||||||
}()
|
|
||||||
if panicked {
|
|
||||||
responseWriter.WriteHeader(500)
|
|
||||||
} else {
|
|
||||||
responseWriter.WriteHeader(200)
|
|
||||||
}
|
|
||||||
if responseWriter.dataStream != nil {
|
|
||||||
if !streamEnded && !reqBody.requestRead {
|
|
||||||
// in gQUIC, the error code doesn't matter, so just use 0 here
|
|
||||||
responseWriter.dataStream.CancelRead(0)
|
|
||||||
}
|
}
|
||||||
responseWriter.dataStream.Close()
|
}()
|
||||||
}
|
handler.ServeHTTP(responseWriter, req)
|
||||||
if s.CloseAfterFirstRequest {
|
// read the eof
|
||||||
time.Sleep(100 * time.Millisecond)
|
if _, err = str.Read([]byte{}); err == io.EOF {
|
||||||
session.Close()
|
readEOF = true
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if panicked {
|
||||||
|
responseWriter.WriteHeader(500)
|
||||||
|
} else {
|
||||||
|
responseWriter.WriteHeader(200)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !readEOF {
|
||||||
|
str.CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,7 +263,7 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
||||||
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
|
// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
|
||||||
// used when handler is nil.
|
// used when handler is nil.
|
||||||
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
|
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
|
||||||
server := &Server{
|
server := &Server{
|
408
http3/server_test.go
Normal file
408
http3/server_test.go
Normal file
|
@ -0,0 +1,408 @@
|
||||||
|
package http3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/marten-seemann/qpack"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Server", func() {
|
||||||
|
var (
|
||||||
|
s *Server
|
||||||
|
// session *mockquic.MockSession
|
||||||
|
origQuicListenAddr = quicListenAddr
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
s = &Server{
|
||||||
|
Server: &http.Server{
|
||||||
|
TLSConfig: testdata.GetTLSConfig(),
|
||||||
|
},
|
||||||
|
logger: utils.DefaultLogger,
|
||||||
|
}
|
||||||
|
origQuicListenAddr = quicListenAddr
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
quicListenAddr = origQuicListenAddr
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("handling requests", func() {
|
||||||
|
var (
|
||||||
|
qpackDecoder *qpack.Decoder
|
||||||
|
str *mockquic.MockStream
|
||||||
|
exampleGetRequest *http.Request
|
||||||
|
examplePostRequest *http.Request
|
||||||
|
)
|
||||||
|
reqContext := context.Background()
|
||||||
|
|
||||||
|
decodeHeader := func(str io.Reader) map[string][]string {
|
||||||
|
fields := make(map[string][]string)
|
||||||
|
decoder := qpack.NewDecoder(nil)
|
||||||
|
|
||||||
|
frame, err := parseNextFrame(str)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
||||||
|
headersFrame := frame.(*headersFrame)
|
||||||
|
data := make([]byte, headersFrame.Length)
|
||||||
|
_, err = io.ReadFull(str, data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
hfs, err := decoder.DecodeFull(data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
for _, p := range hfs {
|
||||||
|
fields[p.Name] = append(fields[p.Name], p.Value)
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
encodeRequest := func(req *http.Request) []byte {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
str := mockquic.NewMockStream(mockCtrl)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return buf.Write(p)
|
||||||
|
}).AnyTimes()
|
||||||
|
closed := make(chan struct{})
|
||||||
|
str.EXPECT().Close().Do(func() { close(closed) })
|
||||||
|
rw := newRequestWriter(utils.DefaultLogger)
|
||||||
|
Expect(rw.WriteRequest(str, req)).To(Succeed())
|
||||||
|
if req.Body != nil {
|
||||||
|
b := make([]byte, 1000)
|
||||||
|
n, err := io.ReadFull(req.Body, b)
|
||||||
|
Expect(err).To(Equal(io.ErrUnexpectedEOF)) // otherwise b is too small for this test
|
||||||
|
(&dataFrame{Length: uint64(n)}).Write(buf)
|
||||||
|
buf.Write(b[:n])
|
||||||
|
}
|
||||||
|
Eventually(closed).Should(BeClosed())
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
setRequest := func(data []byte) {
|
||||||
|
buf := bytes.NewBuffer(data)
|
||||||
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
if buf.Len() == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
return buf.Read(p)
|
||||||
|
}).AnyTimes()
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar")))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
qpackDecoder = qpack.NewDecoder(nil)
|
||||||
|
str = mockquic.NewMockStream(mockCtrl)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("calls the HTTP handler function", func() {
|
||||||
|
requestChan := make(chan *http.Request, 1)
|
||||||
|
s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||||
|
requestChan <- r
|
||||||
|
})
|
||||||
|
|
||||||
|
setRequest(encodeRequest(exampleGetRequest))
|
||||||
|
str.EXPECT().Context().Return(reqContext)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
|
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
|
||||||
|
var req *http.Request
|
||||||
|
Eventually(requestChan).Should(Receive(&req))
|
||||||
|
Expect(req.Host).To(Equal("www.example.com"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns 200 with an empty handler", func() {
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
|
responseBuf := &bytes.Buffer{}
|
||||||
|
setRequest(encodeRequest(exampleGetRequest))
|
||||||
|
str.EXPECT().Context().Return(reqContext)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return responseBuf.Write(p)
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
|
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
|
||||||
|
hfs := decodeHeader(responseBuf)
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("handles a panicking handler", func() {
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("foobar")
|
||||||
|
})
|
||||||
|
|
||||||
|
responseBuf := &bytes.Buffer{}
|
||||||
|
setRequest(encodeRequest(exampleGetRequest))
|
||||||
|
str.EXPECT().Context().Return(reqContext)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return responseBuf.Write(p)
|
||||||
|
}).AnyTimes()
|
||||||
|
str.EXPECT().CancelRead(gomock.Any())
|
||||||
|
|
||||||
|
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
|
||||||
|
hfs := decodeHeader(responseBuf)
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("cancels reading when client sends a body in GET request", func() {
|
||||||
|
handlerCalled := make(chan struct{})
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
close(handlerCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
requestData := encodeRequest(exampleGetRequest)
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
(&dataFrame{Length: 6}).Write(buf) // add a body
|
||||||
|
buf.Write([]byte("foobar"))
|
||||||
|
responseBuf := &bytes.Buffer{}
|
||||||
|
setRequest(append(requestData, buf.Bytes()...))
|
||||||
|
str.EXPECT().Context().Return(reqContext)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return responseBuf.Write(p)
|
||||||
|
}).AnyTimes()
|
||||||
|
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
|
||||||
|
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
|
||||||
|
hfs := decodeHeader(responseBuf)
|
||||||
|
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("cancels reading when the body of POST request is not read", func() {
|
||||||
|
handlerCalled := make(chan struct{})
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
Expect(r.Host).To(Equal("www.example.com"))
|
||||||
|
Expect(r.Method).To(Equal("POST"))
|
||||||
|
close(handlerCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
setRequest(encodeRequest(examplePostRequest))
|
||||||
|
str.EXPECT().Context().Return(reqContext)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}).AnyTimes()
|
||||||
|
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
|
||||||
|
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
|
||||||
|
Eventually(handlerCalled).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("handles a request for which the client immediately resets the stream", func() {
|
||||||
|
handlerCalled := make(chan struct{})
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
close(handlerCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
testErr := errors.New("stream reset")
|
||||||
|
str.EXPECT().Read(gomock.Any()).Return(0, testErr)
|
||||||
|
str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled))
|
||||||
|
|
||||||
|
Expect(s.handleRequest(str, qpackDecoder)).To(MatchError(testErr))
|
||||||
|
Consistently(handlerCalled).ShouldNot(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
|
||||||
|
handlerCalled := make(chan struct{})
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.Body = struct {
|
||||||
|
io.Reader
|
||||||
|
io.Closer
|
||||||
|
}{}
|
||||||
|
close(handlerCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
setRequest(encodeRequest(examplePostRequest))
|
||||||
|
str.EXPECT().Context().Return(reqContext)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}).AnyTimes()
|
||||||
|
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
|
||||||
|
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
|
||||||
|
Eventually(handlerCalled).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("cancels the request context when the stream is closed", func() {
|
||||||
|
handlerCalled := make(chan struct{})
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(r.Context().Done()).To(BeClosed())
|
||||||
|
Expect(r.Context().Err()).To(MatchError(context.Canceled))
|
||||||
|
close(handlerCalled)
|
||||||
|
})
|
||||||
|
setRequest(encodeRequest(examplePostRequest))
|
||||||
|
|
||||||
|
reqContext, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
str.EXPECT().Context().Return(reqContext)
|
||||||
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}).AnyTimes()
|
||||||
|
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
|
||||||
|
Expect(s.handleRequest(str, qpackDecoder)).To(Succeed())
|
||||||
|
Eventually(handlerCalled).Should(BeClosed())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("setting http headers", func() {
|
||||||
|
var expected http.Header
|
||||||
|
|
||||||
|
getExpectedHeader := func(versions []protocol.VersionNumber) http.Header {
|
||||||
|
var versionsAsString []string
|
||||||
|
for _, v := range versions {
|
||||||
|
versionsAsString = append(versionsAsString, v.ToAltSvc())
|
||||||
|
}
|
||||||
|
return http.Header{
|
||||||
|
"Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
Expect(getExpectedHeader([]protocol.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}}))
|
||||||
|
expected = getExpectedHeader(protocol.SupportedVersions)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets proper headers with numeric port", func() {
|
||||||
|
s.Server.Addr = ":443"
|
||||||
|
hdr := http.Header{}
|
||||||
|
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
|
||||||
|
Expect(hdr).To(Equal(expected))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets proper headers with full addr", func() {
|
||||||
|
s.Server.Addr = "127.0.0.1:443"
|
||||||
|
hdr := http.Header{}
|
||||||
|
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
|
||||||
|
Expect(hdr).To(Equal(expected))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets proper headers with string port", func() {
|
||||||
|
s.Server.Addr = ":https"
|
||||||
|
hdr := http.Header{}
|
||||||
|
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
|
||||||
|
Expect(hdr).To(Equal(expected))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("works multiple times", func() {
|
||||||
|
s.Server.Addr = ":https"
|
||||||
|
hdr := http.Header{}
|
||||||
|
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
|
||||||
|
Expect(hdr).To(Equal(expected))
|
||||||
|
hdr = http.Header{}
|
||||||
|
Expect(s.SetQuicHeaders(hdr)).To(Succeed())
|
||||||
|
Expect(hdr).To(Equal(expected))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when ListenAndServe is called with s.Server nil", func() {
|
||||||
|
Expect((&Server{}).ListenAndServe()).To(MatchError("use of http3.Server without http.Server"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when ListenAndServeTLS is called with s.Server nil", func() {
|
||||||
|
Expect((&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError("use of http3.Server without http.Server"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should nop-Close() when s.server is nil", func() {
|
||||||
|
Expect((&Server{}).Close()).To(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when ListenAndServer is called after Close", func() {
|
||||||
|
serv := &Server{Server: &http.Server{}}
|
||||||
|
Expect(serv.Close()).To(Succeed())
|
||||||
|
Expect(serv.ListenAndServe()).To(MatchError("Server is already closed"))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("ListenAndServe", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
s.Server.Addr = "localhost:0"
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
Expect(s.Close()).To(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("may only be called once", func() {
|
||||||
|
cErr := make(chan error)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
if err := s.ListenAndServe(); err != nil {
|
||||||
|
cErr <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
Eventually(cErr).Should(Receive(MatchError("ListenAndServe may only be called once")))
|
||||||
|
Expect(s.Close()).To(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("uses the quic.Config to start the quic server", func() {
|
||||||
|
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
||||||
|
var receivedConf *quic.Config
|
||||||
|
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
||||||
|
receivedConf = config
|
||||||
|
return nil, errors.New("listen err")
|
||||||
|
}
|
||||||
|
s.QuicConfig = conf
|
||||||
|
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||||
|
Expect(receivedConf).To(Equal(conf))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("ListenAndServeTLS", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
s.Server.Addr = "localhost:0"
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
Expect(s.Close()).To(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("may only be called once", func() {
|
||||||
|
cErr := make(chan error)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
if err := s.ListenAndServeTLS(testdata.GetCertificatePaths()); err != nil {
|
||||||
|
cErr <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
Eventually(cErr).Should(Receive(MatchError("ListenAndServe may only be called once")))
|
||||||
|
Expect(s.Close()).To(Succeed())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes gracefully", func() {
|
||||||
|
Expect(s.CloseGracefully(0)).To(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when listening fails", func() {
|
||||||
|
testErr := errors.New("listen error")
|
||||||
|
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
||||||
|
return nil, testErr
|
||||||
|
}
|
||||||
|
fullpem, privkey := testdata.GetCertificatePaths()
|
||||||
|
Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr))
|
||||||
|
})
|
||||||
|
})
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/http3"
|
||||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
|
@ -38,7 +38,7 @@ var _ = Describe("HTTP tests", func() {
|
||||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
client = &http.Client{
|
client = &http.Client{
|
||||||
Transport: &h2quic.RoundTripper{
|
Transport: &http3.RoundTripper{
|
||||||
TLSClientConfig: &tls.Config{
|
TLSClientConfig: &tls.Config{
|
||||||
RootCAs: testdata.GetRootCA(),
|
RootCAs: testdata.GetRootCA(),
|
||||||
},
|
},
|
||||||
|
@ -77,8 +77,20 @@ var _ = Describe("HTTP tests", func() {
|
||||||
Expect(body).To(Equal(testserver.PRDataLong))
|
Expect(body).To(Equal(testserver.PRDataLong))
|
||||||
})
|
})
|
||||||
|
|
||||||
// TODO(#1756): this test times out
|
It("downloads many hellos", func() {
|
||||||
PIt("downloads many files, if the response is not read", func() {
|
const num = 150
|
||||||
|
|
||||||
|
for i := 0; i < num; i++ {
|
||||||
|
resp, err := client.Get("https://localhost:" + testserver.Port() + "/hello")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(200))
|
||||||
|
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(string(body)).To(Equal("Hello, World!\n"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("downloads many files, if the response is not read", func() {
|
||||||
const num = 150
|
const num = 150
|
||||||
|
|
||||||
for i := 0; i < num; i++ {
|
for i := 0; i < num; i++ {
|
||||||
|
@ -89,6 +101,19 @@ var _ = Describe("HTTP tests", func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("posts a small message", func() {
|
||||||
|
resp, err := client.Post(
|
||||||
|
"https://localhost:"+testserver.Port()+"/echo",
|
||||||
|
"text/plain",
|
||||||
|
bytes.NewReader([]byte("Hello, world!")),
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(200))
|
||||||
|
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(body).To(Equal([]byte("Hello, world!")))
|
||||||
|
})
|
||||||
|
|
||||||
It("uploads a file", func() {
|
It("uploads a file", func() {
|
||||||
resp, err := client.Post(
|
resp, err := client.Post(
|
||||||
"https://localhost:"+testserver.Port()+"/echo",
|
"https://localhost:"+testserver.Port()+"/echo",
|
||||||
|
@ -99,7 +124,7 @@ var _ = Describe("HTTP tests", func() {
|
||||||
Expect(resp.StatusCode).To(Equal(200))
|
Expect(resp.StatusCode).To(Equal(200))
|
||||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
|
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(bytes.Equal(body, testserver.PRData)).To(BeTrue())
|
Expect(body).To(Equal(testserver.PRData))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lucas-clemente/quic-go/h2quic"
|
"github.com/lucas-clemente/quic-go/http3"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ var (
|
||||||
// PRDataLong contains dataLenLong bytes of pseudo-random data.
|
// PRDataLong contains dataLenLong bytes of pseudo-random data.
|
||||||
PRDataLong = GeneratePRData(dataLenLong)
|
PRDataLong = GeneratePRData(dataLenLong)
|
||||||
|
|
||||||
server *h2quic.Server
|
server *http3.Server
|
||||||
stoppedServing chan struct{}
|
stoppedServing chan struct{}
|
||||||
port string
|
port string
|
||||||
)
|
)
|
||||||
|
@ -75,10 +75,10 @@ func GeneratePRData(l int) []byte {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartQuicServer starts a h2quic.Server.
|
// StartQuicServer starts a http3.Server.
|
||||||
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
||||||
func StartQuicServer(versions []protocol.VersionNumber) {
|
func StartQuicServer(versions []protocol.VersionNumber) {
|
||||||
server = &h2quic.Server{
|
server = &http3.Server{
|
||||||
Server: &http.Server{
|
Server: &http.Server{
|
||||||
TLSConfig: testdata.GetTLSConfig(),
|
TLSConfig: testdata.GetTLSConfig(),
|
||||||
},
|
},
|
||||||
|
@ -102,7 +102,7 @@ func StartQuicServer(versions []protocol.VersionNumber) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// StopQuicServer stops the h2quic.Server.
|
// StopQuicServer stops the http3.Server.
|
||||||
func StopQuicServer() {
|
func StopQuicServer() {
|
||||||
Expect(server.Close()).NotTo(HaveOccurred())
|
Expect(server.Close()).NotTo(HaveOccurred())
|
||||||
Eventually(stoppedServing).Should(BeClosed())
|
Eventually(stoppedServing).Should(BeClosed())
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
|
//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go"
|
||||||
|
//go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go"
|
||||||
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
|
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
|
||||||
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
|
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
|
||||||
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
|
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
|
||||||
|
|
213
internal/mocks/quic/session.go
Normal file
213
internal/mocks/quic/session.go
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/lucas-clemente/quic-go (interfaces: Session)
|
||||||
|
|
||||||
|
// Package mockquic is a generated GoMock package.
|
||||||
|
package mockquic
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
tls "crypto/tls"
|
||||||
|
net "net"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
quic_go "github.com/lucas-clemente/quic-go"
|
||||||
|
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockSession is a mock of Session interface
|
||||||
|
type MockSession struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockSessionMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockSessionMockRecorder is the mock recorder for MockSession
|
||||||
|
type MockSessionMockRecorder struct {
|
||||||
|
mock *MockSession
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockSession creates a new mock instance
|
||||||
|
func NewMockSession(ctrl *gomock.Controller) *MockSession {
|
||||||
|
mock := &MockSession{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockSessionMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockSession) EXPECT() *MockSessionMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptStream mocks base method
|
||||||
|
func (m *MockSession) AcceptStream() (quic_go.Stream, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AcceptStream")
|
||||||
|
ret0, _ := ret[0].(quic_go.Stream)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptStream indicates an expected call of AcceptStream
|
||||||
|
func (mr *MockSessionMockRecorder) AcceptStream() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockSession)(nil).AcceptStream))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptUniStream mocks base method
|
||||||
|
func (m *MockSession) AcceptUniStream() (quic_go.ReceiveStream, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AcceptUniStream")
|
||||||
|
ret0, _ := ret[0].(quic_go.ReceiveStream)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptUniStream indicates an expected call of AcceptUniStream
|
||||||
|
func (mr *MockSessionMockRecorder) AcceptUniStream() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockSession)(nil).AcceptUniStream))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close mocks base method
|
||||||
|
func (m *MockSession) Close() error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Close")
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close indicates an expected call of Close
|
||||||
|
func (mr *MockSessionMockRecorder) Close() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSession)(nil).Close))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWithError mocks base method
|
||||||
|
func (m *MockSession) CloseWithError(arg0 protocol.ApplicationErrorCode, arg1 error) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWithError indicates an expected call of CloseWithError
|
||||||
|
func (mr *MockSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockSession)(nil).CloseWithError), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionState mocks base method
|
||||||
|
func (m *MockSession) ConnectionState() tls.ConnectionState {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ConnectionState")
|
||||||
|
ret0, _ := ret[0].(tls.ConnectionState)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionState indicates an expected call of ConnectionState
|
||||||
|
func (mr *MockSessionMockRecorder) ConnectionState() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockSession)(nil).ConnectionState))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context mocks base method
|
||||||
|
func (m *MockSession) Context() context.Context {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Context")
|
||||||
|
ret0, _ := ret[0].(context.Context)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context indicates an expected call of Context
|
||||||
|
func (mr *MockSessionMockRecorder) Context() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSession)(nil).Context))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr mocks base method
|
||||||
|
func (m *MockSession) LocalAddr() net.Addr {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LocalAddr")
|
||||||
|
ret0, _ := ret[0].(net.Addr)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr indicates an expected call of LocalAddr
|
||||||
|
func (mr *MockSessionMockRecorder) LocalAddr() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockSession)(nil).LocalAddr))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStream mocks base method
|
||||||
|
func (m *MockSession) OpenStream() (quic_go.Stream, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "OpenStream")
|
||||||
|
ret0, _ := ret[0].(quic_go.Stream)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStream indicates an expected call of OpenStream
|
||||||
|
func (mr *MockSessionMockRecorder) OpenStream() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockSession)(nil).OpenStream))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStreamSync mocks base method
|
||||||
|
func (m *MockSession) OpenStreamSync() (quic_go.Stream, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "OpenStreamSync")
|
||||||
|
ret0, _ := ret[0].(quic_go.Stream)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStreamSync indicates an expected call of OpenStreamSync
|
||||||
|
func (mr *MockSessionMockRecorder) OpenStreamSync() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockSession)(nil).OpenStreamSync))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenUniStream mocks base method
|
||||||
|
func (m *MockSession) OpenUniStream() (quic_go.SendStream, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "OpenUniStream")
|
||||||
|
ret0, _ := ret[0].(quic_go.SendStream)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenUniStream indicates an expected call of OpenUniStream
|
||||||
|
func (mr *MockSessionMockRecorder) OpenUniStream() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockSession)(nil).OpenUniStream))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenUniStreamSync mocks base method
|
||||||
|
func (m *MockSession) OpenUniStreamSync() (quic_go.SendStream, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "OpenUniStreamSync")
|
||||||
|
ret0, _ := ret[0].(quic_go.SendStream)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenUniStreamSync indicates an expected call of OpenUniStreamSync
|
||||||
|
func (mr *MockSessionMockRecorder) OpenUniStreamSync() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockSession)(nil).OpenUniStreamSync))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr mocks base method
|
||||||
|
func (m *MockSession) RemoteAddr() net.Addr {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RemoteAddr")
|
||||||
|
ret0, _ := ret[0].(net.Addr)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr indicates an expected call of RemoteAddr
|
||||||
|
func (mr *MockSessionMockRecorder) RemoteAddr() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockSession)(nil).RemoteAddr))
|
||||||
|
}
|
175
internal/mocks/quic/stream.go
Normal file
175
internal/mocks/quic/stream.go
Normal file
|
@ -0,0 +1,175 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/lucas-clemente/quic-go (interfaces: Stream)
|
||||||
|
|
||||||
|
// Package mockquic is a generated GoMock package.
|
||||||
|
package mockquic
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
reflect "reflect"
|
||||||
|
time "time"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockStream is a mock of Stream interface
|
||||||
|
type MockStream struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockStreamMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockStreamMockRecorder is the mock recorder for MockStream
|
||||||
|
type MockStreamMockRecorder struct {
|
||||||
|
mock *MockStream
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockStream creates a new mock instance
|
||||||
|
func NewMockStream(ctrl *gomock.Controller) *MockStream {
|
||||||
|
mock := &MockStream{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockStreamMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockStream) EXPECT() *MockStreamMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelRead mocks base method
|
||||||
|
func (m *MockStream) CancelRead(arg0 protocol.ApplicationErrorCode) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "CancelRead", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelRead indicates an expected call of CancelRead
|
||||||
|
func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStream)(nil).CancelRead), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelWrite mocks base method
|
||||||
|
func (m *MockStream) CancelWrite(arg0 protocol.ApplicationErrorCode) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "CancelWrite", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelWrite indicates an expected call of CancelWrite
|
||||||
|
func (mr *MockStreamMockRecorder) CancelWrite(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStream)(nil).CancelWrite), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close mocks base method
|
||||||
|
func (m *MockStream) Close() error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Close")
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close indicates an expected call of Close
|
||||||
|
func (mr *MockStreamMockRecorder) Close() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStream)(nil).Close))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context mocks base method
|
||||||
|
func (m *MockStream) Context() context.Context {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Context")
|
||||||
|
ret0, _ := ret[0].(context.Context)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context indicates an expected call of Context
|
||||||
|
func (mr *MockStreamMockRecorder) Context() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStream)(nil).Context))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read mocks base method
|
||||||
|
func (m *MockStream) Read(arg0 []byte) (int, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Read", arg0)
|
||||||
|
ret0, _ := ret[0].(int)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read indicates an expected call of Read
|
||||||
|
func (mr *MockStreamMockRecorder) Read(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline mocks base method
|
||||||
|
func (m *MockStream) SetDeadline(arg0 time.Time) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "SetDeadline", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline indicates an expected call of SetDeadline
|
||||||
|
func (mr *MockStreamMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline mocks base method
|
||||||
|
func (m *MockStream) SetReadDeadline(arg0 time.Time) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "SetReadDeadline", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline indicates an expected call of SetReadDeadline
|
||||||
|
func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline mocks base method
|
||||||
|
func (m *MockStream) SetWriteDeadline(arg0 time.Time) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "SetWriteDeadline", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline indicates an expected call of SetWriteDeadline
|
||||||
|
func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamID mocks base method
|
||||||
|
func (m *MockStream) StreamID() protocol.StreamID {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "StreamID")
|
||||||
|
ret0, _ := ret[0].(protocol.StreamID)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamID indicates an expected call of StreamID
|
||||||
|
func (mr *MockStreamMockRecorder) StreamID() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStream)(nil).StreamID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write mocks base method
|
||||||
|
func (m *MockStream) Write(arg0 []byte) (int, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Write", arg0)
|
||||||
|
ret0, _ := ret[0].(int)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write indicates an expected call of Write
|
||||||
|
func (mr *MockStreamMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue