privatize the client, only expose Dial functions

This commit is contained in:
Marten Seemann 2017-02-22 12:21:33 +07:00
parent 48dee2708e
commit 96edca5219
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
7 changed files with 214 additions and 188 deletions

121
client.go
View file

@ -6,6 +6,7 @@ import (
"errors"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@ -14,45 +15,35 @@ import (
"github.com/lucas-clemente/quic-go/utils"
)
// A Client of QUIC
type Client struct {
type client struct {
mutex sync.Mutex
connStateChangeCond sync.Cond
conn connection
hostname string
config *Config
connectionID protocol.ConnectionID
version protocol.VersionNumber
versionNegotiated bool
closed uint32 // atomic bool
tlsConfig *tls.Config
cryptoChangeCallback CryptoChangeCallback
versionNegotiateCallback VersionNegotiateCallback
tlsConfig *tls.Config
cryptoChangeCallback CryptoChangeCallback
session packetHandler
}
// VersionNegotiateCallback is called once the client has a negotiated version
type VersionNegotiateCallback func() error
var errHostname = errors.New("Invalid hostname")
var (
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
)
// NewClient makes a new client
func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
udpAddr, err := net.ResolveUDPAddr("udp", host)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
connectionID, err := utils.GenerateConnectionID()
// Dial establishes a new QUIC connection to a server
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
connID, err := utils.GenerateConnectionID()
if err != nil {
return nil, err
}
@ -62,28 +53,65 @@ func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoCh
return nil, err
}
client := &Client{
conn: &conn{pconn: udpConn, currentAddr: udpAddr},
hostname: hostname,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
connectionID: connectionID,
tlsConfig: tlsConfig,
cryptoChangeCallback: cryptoChangeCallback,
versionNegotiateCallback: versionNegotiateCallback,
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
config: config,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
}
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version)
c.connStateChangeCond.L = &c.mutex
err = client.createNewSession(nil)
c.cryptoChangeCallback = func(isForwardSecure bool) {
var state ConnState
if isForwardSecure {
state = ConnStateForwardSecure
} else {
state = ConnStateSecure
}
if c.config.ConnState != nil {
go config.ConnState(c.session, state)
}
}
err = c.createNewSession(nil)
if err != nil {
return nil, err
}
return client, nil
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version)
// TODO: handle errors
go c.Listen()
c.mutex.Lock()
for !c.versionNegotiated {
c.connStateChangeCond.Wait()
}
c.mutex.Unlock()
return c.session, nil
}
// DialAddr establishes a new QUIC connection to a server
func DialAddr(hostname string, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", hostname)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return Dial(udpConn, udpAddr, hostname, config)
}
// Listen listens
func (c *Client) Listen() error {
func (c *client) Listen() error {
for {
data := getPacketBuffer()
data = data[:protocol.MaxPacketSize]
@ -106,13 +134,8 @@ func (c *Client) Listen() error {
}
}
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
func (c *Client) OpenStream() (Stream, error) {
return c.session.OpenStream()
}
// Close closes the connection
func (c *Client) Close(e error) error {
func (c *client) Close(e error) error {
// Only close once
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return nil
@ -122,7 +145,7 @@ func (c *Client) Close(e error) error {
return c.conn.Close()
}
func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error {
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
return qerr.PacketTooLarge
}
@ -145,10 +168,12 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error {
// this is the first packet after the client sent a packet with the VersionFlag set
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
c.mutex.Lock()
c.versionNegotiated = true
err = c.versionNegotiateCallback()
if err != nil {
return err
c.connStateChangeCond.Signal()
c.mutex.Unlock()
if c.config.ConnState != nil {
go c.config.ConnState(c.session, ConnStateVersionNegotiated)
}
}
@ -187,7 +212,9 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error {
if err != nil {
return err
}
err = c.versionNegotiateCallback()
if c.config.ConnState != nil {
go c.config.ConnState(c.session, ConnStateVersionNegotiated)
}
if err != nil {
return err
}
@ -204,14 +231,14 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error {
return nil
}
func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
var err error
c.session, err = newClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.tlsConfig,
c.config.TLSConfig,
c.closeCallback,
c.cryptoChangeCallback,
negotiatedVersions)
@ -223,6 +250,6 @@ func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
return nil
}
func (c *Client) closeCallback(id protocol.ConnectionID) {
func (c *client) closeCallback(id protocol.ConnectionID) {
utils.Infof("Connection %x closed.", id)
}

View file

@ -19,44 +19,51 @@ import (
var _ = Describe("Client", func() {
var (
client *Client
sess *mockSession
packetConn *mockPacketConn
versionNegotiateCallbackCalled bool
cl *client
config *Config
sess *mockSession
packetConn *mockPacketConn
addr net.Addr
versionNegotiateConnStateCalled bool
)
BeforeEach(func() {
versionNegotiateConnStateCalled = false
packetConn = &mockPacketConn{}
versionNegotiateCallbackCalled = false
client = &Client{
versionNegotiateCallback: func() error {
versionNegotiateCallbackCalled = true
return nil
config = &Config{
ConnState: func(_ Session, state ConnState) {
if state == ConnStateVersionNegotiated {
versionNegotiateConnStateCalled = true
}
},
}
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
sess = &mockSession{connectionID: 0x1337}
client.connectionID = 0x1337
client.session = sess
client.version = protocol.Version36
client.conn = &conn{pconn: packetConn, currentAddr: addr}
cl = &client{
config: config,
connectionID: 0x1337,
session: sess,
version: protocol.Version36,
conn: &conn{pconn: packetConn, currentAddr: addr},
}
})
It("creates a new client", func() {
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
var err error
client, err = NewClient("quic.clemente.io:1337", nil, nil, nil)
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred())
Expect(client.hostname).To(Equal("quic.clemente.io"))
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
})
It("errors on invalid public header", func() {
err := client.handlePacket(nil, nil)
err := cl.handlePacket(nil, nil)
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
})
It("errors on large packets", func() {
err := client.handlePacket(nil, bytes.Repeat([]byte{'a'}, int(protocol.MaxPacketSize)+1))
err := cl.handlePacket(nil, bytes.Repeat([]byte{'a'}, int(protocol.MaxPacketSize)+1))
Expect(err).To(MatchError(qerr.PacketTooLarge))
})
@ -68,51 +75,45 @@ var _ = Describe("Client", func() {
var stoppedListening bool
go func() {
defer GinkgoRecover()
err := client.Listen()
err := cl.Listen()
Expect(err).ToNot(HaveOccurred())
stoppedListening = true
}()
err := client.Close(testErr)
err := cl.Close(testErr)
Expect(err).ToNot(HaveOccurred())
Eventually(sess.closed).Should(BeTrue())
Expect(sess.closeReason).To(MatchError(testErr))
Expect(client.closed).To(Equal(uint32(1)))
Expect(cl.closed).To(Equal(uint32(1)))
Eventually(func() bool { return stoppedListening }).Should(BeTrue())
Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines))
close(done)
}, 10)
It("only closes the client once", func() {
client.closed = 1
err := client.Close(errors.New("test error"))
cl.closed = 1
err := cl.Close(errors.New("test error"))
Expect(err).ToNot(HaveOccurred())
Eventually(sess.closed).Should(BeFalse())
Expect(sess.closeReason).ToNot(HaveOccurred())
})
It("creates new sessions with the right parameters", func() {
client.session = nil
client.hostname = "hostname"
err := client.createNewSession(nil)
cl.session = nil
cl.hostname = "hostname"
err := cl.createNewSession(nil)
Expect(err).ToNot(HaveOccurred())
Expect(client.session).ToNot(BeNil())
Expect(client.session.(*session).connectionID).To(Equal(client.connectionID))
Expect(client.session.(*session).version).To(Equal(client.version))
Expect(cl.session).ToNot(BeNil())
Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID))
Expect(cl.session.(*session).version).To(Equal(cl.version))
err = client.Close(nil)
err = cl.Close(nil)
Expect(err).ToNot(HaveOccurred())
})
It("opens a stream", func() {
stream, err := client.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(stream).ToNot(BeNil())
})
Context("handling packets", func() {
It("errors on too large packets", func() {
err := client.handlePacket(nil, bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize+1)))
err := cl.handlePacket(nil, bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize+1)))
Expect(err).To(MatchError(qerr.PacketTooLarge))
})
@ -130,7 +131,7 @@ var _ = Describe("Client", func() {
Expect(sess.packetCount).To(BeZero())
var stoppedListening bool
go func() {
_ = client.Listen()
_ = cl.Listen()
// it should continue listening when receiving valid packets
stoppedListening = true
}()
@ -142,7 +143,7 @@ var _ = Describe("Client", func() {
It("closes the session when encountering an error while handling a packet", func() {
packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100)
listenErr := client.Listen()
listenErr := cl.Listen()
Expect(listenErr).To(HaveOccurred())
Expect(sess.closed).To(BeTrue())
Expect(sess.closeReason).To(MatchError(listenErr))
@ -160,7 +161,7 @@ var _ = Describe("Client", func() {
b.Write(s)
}
protocol.SupportedVersionsAsTags = b.Bytes()
packet := composeVersionNegotiation(client.connectionID)
packet := composeVersionNegotiation(cl.connectionID)
protocol.SupportedVersionsAsTags = oldSupportVersionTags
Expect(composeVersionNegotiation(0x1337)).To(Equal(oldVersionNegotiationPacket))
return packet
@ -175,51 +176,51 @@ var _ = Describe("Client", func() {
b := &bytes.Buffer{}
err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
err = client.handlePacket(nil, b.Bytes())
err = cl.handlePacket(nil, b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(client.versionNegotiated).To(BeTrue())
Expect(versionNegotiateCallbackCalled).To(BeTrue())
Expect(cl.versionNegotiated).To(BeTrue())
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
})
It("changes the version after receiving a version negotiation packet", func() {
newVersion := protocol.Version35
Expect(newVersion).ToNot(Equal(client.version))
Expect(newVersion).ToNot(Equal(cl.version))
Expect(sess.packetCount).To(BeZero())
client.connectionID = 0x1337
err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
Expect(client.version).To(Equal(newVersion))
Expect(client.versionNegotiated).To(BeTrue())
Expect(versionNegotiateCallbackCalled).To(BeTrue())
cl.connectionID = 0x1337
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
Expect(cl.version).To(Equal(newVersion))
Expect(cl.versionNegotiated).To(BeTrue())
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
// it swapped the sessions
Expect(client.session).ToNot(Equal(sess))
Expect(client.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
Expect(cl.session).ToNot(Equal(sess))
Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
Expect(err).ToNot(HaveOccurred())
// it didn't pass the version negoation packet to the session (since it has no payload)
Expect(sess.packetCount).To(BeZero())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35}))
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35}))
err = client.Close(nil)
err = cl.Close(nil)
Expect(err).ToNot(HaveOccurred())
})
It("errors if no matching version is found", func() {
err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("ignores delayed version negotiation packets", func() {
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
client.versionNegotiated = true
cl.versionNegotiated = true
Expect(sess.packetCount).To(BeZero())
err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
Expect(err).ToNot(HaveOccurred())
Expect(client.versionNegotiated).To(BeTrue())
Expect(cl.versionNegotiated).To(BeTrue())
Expect(sess.packetCount).To(BeZero())
Expect(versionNegotiateCallbackCalled).To(BeFalse())
Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse())
})
It("errors if the server should have accepted the offered version", func() {
err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{client.version}))
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{cl.version}))
Expect(err).To(MatchError(qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.")))
})
})

View file

@ -20,23 +20,19 @@ import (
"github.com/lucas-clemente/quic-go/utils"
)
type quicClient interface {
OpenStream() (quic.Stream, error)
Close(error) error
Listen() error
}
// Client is a HTTP2 client doing QUIC requests
type Client struct {
mutex sync.RWMutex
cryptoChangedCond sync.Cond
config *quic.Config
t *QuicRoundTripper
hostname string
encryptionLevel protocol.EncryptionLevel
client quicClient
session quic.Session
headerStream quic.Stream
headerErr *qerr.QuicError
requestWriter *requestWriter
@ -47,42 +43,50 @@ type Client struct {
var _ h2quicClient = &Client{}
// NewClient creates a new client
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client {
c := &Client{
t: t,
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
}
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
var err error
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback)
if err != nil {
return nil, err
c.config = &quic.Config{
ConnState: c.connStateCallback,
}
go c.client.Listen()
return c, nil
return c
}
func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
c.cryptoChangedCond.L.Lock()
defer c.cryptoChangedCond.L.Unlock()
// Dial dials the connection
func (c *Client) Dial() error {
_, err := quic.DialAddr(c.hostname, c.config)
return err
}
if isForwardSecure {
c.encryptionLevel = protocol.EncryptionForwardSecure
utils.Debugf("is forward secure")
} else {
func (c *Client) connStateCallback(sess quic.Session, state quic.ConnState) {
c.mutex.Lock()
if c.session == nil {
c.session = sess
}
switch state {
case quic.ConnStateVersionNegotiated:
// TODO: handle errors
c.versionNegotiateCallback()
case quic.ConnStateSecure:
c.encryptionLevel = protocol.EncryptionSecure
utils.Debugf("is secure")
c.cryptoChangedCond.Broadcast()
case quic.ConnStateForwardSecure:
c.encryptionLevel = protocol.EncryptionForwardSecure
utils.Debugf("is forward secure")
c.cryptoChangedCond.Broadcast()
}
c.cryptoChangedCond.Broadcast()
c.mutex.Unlock()
}
func (c *Client) versionNegotiateCallback() error {
var err error
// once the version has been negotiated, open the header stream
c.headerStream, err = c.client.OpenStream()
c.headerStream, err = c.session.OpenStream()
if err != nil {
return err
}
@ -162,7 +166,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
}
hdrChan := make(chan *http.Response)
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
dataStream, err := c.client.OpenStream()
dataStream, err := c.session.OpenStream()
if err != nil {
c.Close(err)
return nil, err
@ -260,7 +264,7 @@ func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
// Close closes the client
func (c *Client) Close(e error) {
_ = c.client.Close(e)
_ = c.session.Close(e)
}
// copied from net/transport.go

View file

@ -4,6 +4,7 @@ import (
"bytes"
"compress/gzip"
"errors"
"net"
"net/http"
"golang.org/x/net/http2"
@ -17,85 +18,73 @@ import (
. "github.com/onsi/gomega"
)
type mockQuicClient struct {
nextStream protocol.StreamID
streams map[protocol.StreamID]*mockStream
closeErr error
}
func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil }
func (m *mockQuicClient) Listen() error { panic("not implemented") }
func (m *mockQuicClient) OpenStream() (quic.Stream, error) {
id := m.nextStream
ms := &mockStream{id: id}
m.streams[id] = ms
m.nextStream += 2
return ms, nil
}
func newMockQuicClient() *mockQuicClient {
return &mockQuicClient{
streams: make(map[protocol.StreamID]*mockStream),
nextStream: 5,
}
}
var _ quicClient = &mockQuicClient{}
var _ = Describe("Client", func() {
var (
client *Client
qClient *mockQuicClient
session *mockSession
headerStream *mockStream
quicTransport *QuicRoundTripper
)
BeforeEach(func() {
var err error
quicTransport = &QuicRoundTripper{}
hostname := "quic.clemente.io:1337"
client, err = NewClient(quicTransport, nil, hostname)
Expect(err).ToNot(HaveOccurred())
client = NewClient(quicTransport, nil, hostname)
Expect(client.hostname).To(Equal(hostname))
qClient = newMockQuicClient()
client.client = qClient
session = &mockSession{}
client.session = session
headerStream = &mockStream{}
qClient.streams[3] = headerStream
headerStream = &mockStream{id: 3}
client.headerStream = headerStream
client.requestWriter = newRequestWriter(headerStream)
})
It("adds the port to the hostname, if none is given", func() {
var err error
client, err = NewClient(quicTransport, nil, "quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
client = NewClient(quicTransport, nil, "quic.clemente.io")
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
It("dials", func() {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
client = NewClient(quicTransport, nil, udpConn.LocalAddr().String())
go client.Dial()
data := make([]byte, 100)
_, err = udpConn.Read(data)
hdr, err := quic.ParsePublicHeader(bytes.NewReader(data), protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.VersionFlag).To(BeTrue())
Expect(hdr.ConnectionID).ToNot(BeNil())
})
It("saves the session when the ConnState callback is called", func() {
client.session = nil // unset the session set in BeforeEach
client.config.ConnState(session, quic.ConnStateForwardSecure)
Expect(client.session).To(Equal(session))
})
It("opens the header stream only after the version has been negotiated", func() {
// delete the headerStream openend in the BeforeEach
client.headerStream = nil
delete(qClient.streams, 3)
qClient.nextStream = 3
session.streamToOpen = headerStream
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
// now start the actual test
err := client.versionNegotiateCallback()
Expect(err).ToNot(HaveOccurred())
client.config.ConnState(session, quic.ConnStateVersionNegotiated)
Expect(client.headerStream).ToNot(BeNil())
Expect(client.headerStream.StreamID()).To(Equal(protocol.StreamID(3)))
})
It("sets the correct crypto level", func() {
Expect(client.encryptionLevel).To(Equal(protocol.Unencrypted))
client.cryptoChangeCallback(false)
client.config.ConnState(session, quic.ConnStateSecure)
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionSecure))
client.cryptoChangeCallback(true)
client.config.ConnState(session, quic.ConnStateForwardSecure)
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
})
Context("Doing requests", func() {
var request *http.Request
var dataStream *mockStream
getRequest := func(data []byte) *http2.MetaHeadersFrame {
r := bytes.NewReader(data)
@ -122,6 +111,9 @@ var _ = Describe("Client", func() {
client.encryptionLevel = protocol.EncryptionForwardSecure
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
dataStream = &mockStream{id: 5}
session.streamToOpen = dataStream
})
It("does a request", func(done Done) {
@ -134,7 +126,6 @@ var _ = Describe("Client", func() {
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
Expect(client.responses).To(HaveKey(protocol.StreamID(5)))
rsp := &http.Response{
Status: "418 I'm a teapot",
@ -144,7 +135,7 @@ var _ = Describe("Client", func() {
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(doErr).ToNot(HaveOccurred())
Expect(doRsp).To(Equal(rsp))
Expect(doRsp.Body).ToNot(BeNil())
Expect(doRsp.Body).To(Equal(dataStream))
Expect(doRsp.ContentLength).To(BeEquivalentTo(-1))
Expect(doRsp.Request).To(Equal(request))
close(done)
@ -172,7 +163,7 @@ var _ = Describe("Client", func() {
Expect(client.headerErr).To(HaveOccurred())
Expect(doErr).To(MatchError(client.headerErr))
Expect(doRsp).To(BeNil())
Expect(client.client.(*mockQuicClient).closeErr).To(MatchError(client.headerErr))
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
})
Context("validating the address", func() {
@ -192,8 +183,7 @@ var _ = Describe("Client", func() {
It("adds the port for request URLs without one", func(done Done) {
var err error
client, err = NewClient(quicTransport, nil, "quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
client = NewClient(quicTransport, nil, "quic.clemente.io")
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
@ -251,7 +241,6 @@ var _ = Describe("Client", func() {
}()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
client.responses[5] <- response
dataStream := qClient.streams[5]
Eventually(func() bool { return doReturned }).Should(BeTrue())
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
Expect(dataStream.closed).To(BeTrue())
@ -317,7 +306,7 @@ var _ = Describe("Client", func() {
go func() { doRsp, doErr = client.Do(request) }()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
qClient.streams[5].dataToRead.Write(gzippedData)
dataStream.dataToRead.Write(gzippedData)
response.Header.Add("Content-Encoding", "gzip")
client.responses[5] <- response
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
@ -350,7 +339,7 @@ var _ = Describe("Client", func() {
go func() { doRsp, doErr = client.Do(request) }()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
qClient.streams[5].dataToRead.Write([]byte("not gzipped"))
dataStream.dataToRead.Write([]byte("not gzipped"))
client.responses[5] <- response
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
Expect(doErr).ToNot(HaveOccurred())
@ -369,7 +358,7 @@ var _ = Describe("Client", func() {
go func() { doRsp, doErr = client.Do(request) }()
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
qClient.streams[5].dataToRead.Write([]byte("gzipped data"))
dataStream.dataToRead.Write([]byte("gzipped data"))
client.responses[5] <- response
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
Expect(doErr).ToNot(HaveOccurred())

View file

@ -12,6 +12,7 @@ import (
)
type h2quicClient interface {
Dial() error
Do(*http.Request) (*http.Response, error)
}
@ -92,8 +93,8 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
client, ok := r.clients[hostname]
if !ok {
var err error
client, err = NewClient(r, r.TLSClientConfig, hostname)
client = NewClient(r, r.TLSClientConfig, hostname)
err := client.Dial()
if err != nil {
return nil, err
}

View file

@ -11,6 +11,9 @@ import (
type mockQuicRoundTripper struct{}
func (m *mockQuicRoundTripper) Dial() error {
return nil
}
func (m *mockQuicRoundTripper) Do(req *http.Request) (*http.Response, error) {
return &http.Response{Request: req}, nil
}

View file

@ -27,6 +27,7 @@ type mockSession struct {
closedWithError error
dataStream quic.Stream
streamToAccept quic.Stream
streamToOpen quic.Stream
}
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
@ -36,7 +37,7 @@ func (s *mockSession) AcceptStream() (quic.Stream, error) {
return s.streamToAccept, nil
}
func (s *mockSession) OpenStream() (quic.Stream, error) {
panic("not implemented")
return s.streamToOpen, nil
}
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
panic("not implemented")