mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 21:57:36 +03:00
privatize the client, only expose Dial functions
This commit is contained in:
parent
48dee2708e
commit
96edca5219
7 changed files with 214 additions and 188 deletions
121
client.go
121
client.go
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -14,45 +15,35 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
// A Client of QUIC
|
||||
type Client struct {
|
||||
type client struct {
|
||||
mutex sync.Mutex
|
||||
connStateChangeCond sync.Cond
|
||||
|
||||
conn connection
|
||||
hostname string
|
||||
|
||||
config *Config
|
||||
|
||||
connectionID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
versionNegotiated bool
|
||||
closed uint32 // atomic bool
|
||||
|
||||
tlsConfig *tls.Config
|
||||
cryptoChangeCallback CryptoChangeCallback
|
||||
versionNegotiateCallback VersionNegotiateCallback
|
||||
tlsConfig *tls.Config
|
||||
cryptoChangeCallback CryptoChangeCallback
|
||||
|
||||
session packetHandler
|
||||
}
|
||||
|
||||
// VersionNegotiateCallback is called once the client has a negotiated version
|
||||
type VersionNegotiateCallback func() error
|
||||
|
||||
var errHostname = errors.New("Invalid hostname")
|
||||
|
||||
var (
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
)
|
||||
|
||||
// NewClient makes a new client
|
||||
func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connectionID, err := utils.GenerateConnectionID()
|
||||
// Dial establishes a new QUIC connection to a server
|
||||
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
|
||||
connID, err := utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -62,28 +53,65 @@ func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoCh
|
|||
return nil, err
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
conn: &conn{pconn: udpConn, currentAddr: udpAddr},
|
||||
hostname: hostname,
|
||||
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
|
||||
connectionID: connectionID,
|
||||
tlsConfig: tlsConfig,
|
||||
cryptoChangeCallback: cryptoChangeCallback,
|
||||
versionNegotiateCallback: versionNegotiateCallback,
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
config: config,
|
||||
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version)
|
||||
c.connStateChangeCond.L = &c.mutex
|
||||
|
||||
err = client.createNewSession(nil)
|
||||
c.cryptoChangeCallback = func(isForwardSecure bool) {
|
||||
var state ConnState
|
||||
if isForwardSecure {
|
||||
state = ConnStateForwardSecure
|
||||
} else {
|
||||
state = ConnStateSecure
|
||||
}
|
||||
|
||||
if c.config.ConnState != nil {
|
||||
go config.ConnState(c.session, state)
|
||||
}
|
||||
}
|
||||
|
||||
err = c.createNewSession(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
// TODO: handle errors
|
||||
go c.Listen()
|
||||
|
||||
c.mutex.Lock()
|
||||
for !c.versionNegotiated {
|
||||
c.connStateChangeCond.Wait()
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server
|
||||
func DialAddr(hostname string, config *Config) (Session, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Dial(udpConn, udpAddr, hostname, config)
|
||||
}
|
||||
|
||||
// Listen listens
|
||||
func (c *Client) Listen() error {
|
||||
func (c *client) Listen() error {
|
||||
for {
|
||||
data := getPacketBuffer()
|
||||
data = data[:protocol.MaxPacketSize]
|
||||
|
@ -106,13 +134,8 @@ func (c *Client) Listen() error {
|
|||
}
|
||||
}
|
||||
|
||||
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
|
||||
func (c *Client) OpenStream() (Stream, error) {
|
||||
return c.session.OpenStream()
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *Client) Close(e error) error {
|
||||
func (c *client) Close(e error) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
|
||||
return nil
|
||||
|
@ -122,7 +145,7 @@ func (c *Client) Close(e error) error {
|
|||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
|
||||
return qerr.PacketTooLarge
|
||||
}
|
||||
|
@ -145,10 +168,12 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
|||
// this is the first packet after the client sent a packet with the VersionFlag set
|
||||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
||||
if !hdr.VersionFlag && !c.versionNegotiated {
|
||||
c.mutex.Lock()
|
||||
c.versionNegotiated = true
|
||||
err = c.versionNegotiateCallback()
|
||||
if err != nil {
|
||||
return err
|
||||
c.connStateChangeCond.Signal()
|
||||
c.mutex.Unlock()
|
||||
if c.config.ConnState != nil {
|
||||
go c.config.ConnState(c.session, ConnStateVersionNegotiated)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -187,7 +212,9 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.versionNegotiateCallback()
|
||||
if c.config.ConnState != nil {
|
||||
go c.config.ConnState(c.session, ConnStateVersionNegotiated)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -204,14 +231,14 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
||||
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
||||
var err error
|
||||
c.session, err = newClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.tlsConfig,
|
||||
c.config.TLSConfig,
|
||||
c.closeCallback,
|
||||
c.cryptoChangeCallback,
|
||||
negotiatedVersions)
|
||||
|
@ -223,6 +250,6 @@ func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) closeCallback(id protocol.ConnectionID) {
|
||||
func (c *client) closeCallback(id protocol.ConnectionID) {
|
||||
utils.Infof("Connection %x closed.", id)
|
||||
}
|
||||
|
|
121
client_test.go
121
client_test.go
|
@ -19,44 +19,51 @@ import (
|
|||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
client *Client
|
||||
sess *mockSession
|
||||
packetConn *mockPacketConn
|
||||
versionNegotiateCallbackCalled bool
|
||||
cl *client
|
||||
config *Config
|
||||
sess *mockSession
|
||||
packetConn *mockPacketConn
|
||||
addr net.Addr
|
||||
versionNegotiateConnStateCalled bool
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
versionNegotiateConnStateCalled = false
|
||||
packetConn = &mockPacketConn{}
|
||||
versionNegotiateCallbackCalled = false
|
||||
client = &Client{
|
||||
versionNegotiateCallback: func() error {
|
||||
versionNegotiateCallbackCalled = true
|
||||
return nil
|
||||
config = &Config{
|
||||
ConnState: func(_ Session, state ConnState) {
|
||||
if state == ConnStateVersionNegotiated {
|
||||
versionNegotiateConnStateCalled = true
|
||||
}
|
||||
},
|
||||
}
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
sess = &mockSession{connectionID: 0x1337}
|
||||
client.connectionID = 0x1337
|
||||
client.session = sess
|
||||
client.version = protocol.Version36
|
||||
client.conn = &conn{pconn: packetConn, currentAddr: addr}
|
||||
cl = &client{
|
||||
config: config,
|
||||
connectionID: 0x1337,
|
||||
session: sess,
|
||||
version: protocol.Version36,
|
||||
conn: &conn{pconn: packetConn, currentAddr: addr},
|
||||
}
|
||||
})
|
||||
|
||||
It("creates a new client", func() {
|
||||
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
|
||||
var err error
|
||||
client, err = NewClient("quic.clemente.io:1337", nil, nil, nil)
|
||||
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io"))
|
||||
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
|
||||
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
|
||||
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
|
||||
})
|
||||
|
||||
It("errors on invalid public header", func() {
|
||||
err := client.handlePacket(nil, nil)
|
||||
err := cl.handlePacket(nil, nil)
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
|
||||
})
|
||||
|
||||
It("errors on large packets", func() {
|
||||
err := client.handlePacket(nil, bytes.Repeat([]byte{'a'}, int(protocol.MaxPacketSize)+1))
|
||||
err := cl.handlePacket(nil, bytes.Repeat([]byte{'a'}, int(protocol.MaxPacketSize)+1))
|
||||
Expect(err).To(MatchError(qerr.PacketTooLarge))
|
||||
})
|
||||
|
||||
|
@ -68,51 +75,45 @@ var _ = Describe("Client", func() {
|
|||
var stoppedListening bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := client.Listen()
|
||||
err := cl.Listen()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
stoppedListening = true
|
||||
}()
|
||||
|
||||
err := client.Close(testErr)
|
||||
err := cl.Close(testErr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(sess.closed).Should(BeTrue())
|
||||
Expect(sess.closeReason).To(MatchError(testErr))
|
||||
Expect(client.closed).To(Equal(uint32(1)))
|
||||
Expect(cl.closed).To(Equal(uint32(1)))
|
||||
Eventually(func() bool { return stoppedListening }).Should(BeTrue())
|
||||
Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines))
|
||||
close(done)
|
||||
}, 10)
|
||||
|
||||
It("only closes the client once", func() {
|
||||
client.closed = 1
|
||||
err := client.Close(errors.New("test error"))
|
||||
cl.closed = 1
|
||||
err := cl.Close(errors.New("test error"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(sess.closed).Should(BeFalse())
|
||||
Expect(sess.closeReason).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("creates new sessions with the right parameters", func() {
|
||||
client.session = nil
|
||||
client.hostname = "hostname"
|
||||
err := client.createNewSession(nil)
|
||||
cl.session = nil
|
||||
cl.hostname = "hostname"
|
||||
err := cl.createNewSession(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.session).ToNot(BeNil())
|
||||
Expect(client.session.(*session).connectionID).To(Equal(client.connectionID))
|
||||
Expect(client.session.(*session).version).To(Equal(client.version))
|
||||
Expect(cl.session).ToNot(BeNil())
|
||||
Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID))
|
||||
Expect(cl.session.(*session).version).To(Equal(cl.version))
|
||||
|
||||
err = client.Close(nil)
|
||||
err = cl.Close(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("opens a stream", func() {
|
||||
stream, err := client.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream).ToNot(BeNil())
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
It("errors on too large packets", func() {
|
||||
err := client.handlePacket(nil, bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize+1)))
|
||||
err := cl.handlePacket(nil, bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize+1)))
|
||||
Expect(err).To(MatchError(qerr.PacketTooLarge))
|
||||
})
|
||||
|
||||
|
@ -130,7 +131,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(sess.packetCount).To(BeZero())
|
||||
var stoppedListening bool
|
||||
go func() {
|
||||
_ = client.Listen()
|
||||
_ = cl.Listen()
|
||||
// it should continue listening when receiving valid packets
|
||||
stoppedListening = true
|
||||
}()
|
||||
|
@ -142,7 +143,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("closes the session when encountering an error while handling a packet", func() {
|
||||
packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100)
|
||||
listenErr := client.Listen()
|
||||
listenErr := cl.Listen()
|
||||
Expect(listenErr).To(HaveOccurred())
|
||||
Expect(sess.closed).To(BeTrue())
|
||||
Expect(sess.closeReason).To(MatchError(listenErr))
|
||||
|
@ -160,7 +161,7 @@ var _ = Describe("Client", func() {
|
|||
b.Write(s)
|
||||
}
|
||||
protocol.SupportedVersionsAsTags = b.Bytes()
|
||||
packet := composeVersionNegotiation(client.connectionID)
|
||||
packet := composeVersionNegotiation(cl.connectionID)
|
||||
protocol.SupportedVersionsAsTags = oldSupportVersionTags
|
||||
Expect(composeVersionNegotiation(0x1337)).To(Equal(oldVersionNegotiationPacket))
|
||||
return packet
|
||||
|
@ -175,51 +176,51 @@ var _ = Describe("Client", func() {
|
|||
b := &bytes.Buffer{}
|
||||
err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = client.handlePacket(nil, b.Bytes())
|
||||
err = cl.handlePacket(nil, b.Bytes())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.versionNegotiated).To(BeTrue())
|
||||
Expect(versionNegotiateCallbackCalled).To(BeTrue())
|
||||
Expect(cl.versionNegotiated).To(BeTrue())
|
||||
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("changes the version after receiving a version negotiation packet", func() {
|
||||
newVersion := protocol.Version35
|
||||
Expect(newVersion).ToNot(Equal(client.version))
|
||||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
client.connectionID = 0x1337
|
||||
err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
|
||||
Expect(client.version).To(Equal(newVersion))
|
||||
Expect(client.versionNegotiated).To(BeTrue())
|
||||
Expect(versionNegotiateCallbackCalled).To(BeTrue())
|
||||
cl.connectionID = 0x1337
|
||||
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
|
||||
Expect(cl.version).To(Equal(newVersion))
|
||||
Expect(cl.versionNegotiated).To(BeTrue())
|
||||
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
|
||||
// it swapped the sessions
|
||||
Expect(client.session).ToNot(Equal(sess))
|
||||
Expect(client.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
|
||||
Expect(cl.session).ToNot(Equal(sess))
|
||||
Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// it didn't pass the version negoation packet to the session (since it has no payload)
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35}))
|
||||
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35}))
|
||||
|
||||
err = client.Close(nil)
|
||||
err = cl.Close(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors if no matching version is found", func() {
|
||||
err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
|
||||
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
|
||||
Expect(err).To(MatchError(qerr.InvalidVersion))
|
||||
})
|
||||
|
||||
It("ignores delayed version negotiation packets", func() {
|
||||
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
|
||||
client.versionNegotiated = true
|
||||
cl.versionNegotiated = true
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
|
||||
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.versionNegotiated).To(BeTrue())
|
||||
Expect(cl.versionNegotiated).To(BeTrue())
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
Expect(versionNegotiateCallbackCalled).To(BeFalse())
|
||||
Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("errors if the server should have accepted the offered version", func() {
|
||||
err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{client.version}))
|
||||
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{cl.version}))
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.")))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -20,23 +20,19 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
type quicClient interface {
|
||||
OpenStream() (quic.Stream, error)
|
||||
Close(error) error
|
||||
Listen() error
|
||||
}
|
||||
|
||||
// Client is a HTTP2 client doing QUIC requests
|
||||
type Client struct {
|
||||
mutex sync.RWMutex
|
||||
cryptoChangedCond sync.Cond
|
||||
|
||||
config *quic.Config
|
||||
|
||||
t *QuicRoundTripper
|
||||
|
||||
hostname string
|
||||
encryptionLevel protocol.EncryptionLevel
|
||||
|
||||
client quicClient
|
||||
session quic.Session
|
||||
headerStream quic.Stream
|
||||
headerErr *qerr.QuicError
|
||||
requestWriter *requestWriter
|
||||
|
@ -47,42 +43,50 @@ type Client struct {
|
|||
var _ h2quicClient = &Client{}
|
||||
|
||||
// NewClient creates a new client
|
||||
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
|
||||
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client {
|
||||
c := &Client{
|
||||
t: t,
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
}
|
||||
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
|
||||
|
||||
var err error
|
||||
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
c.config = &quic.Config{
|
||||
ConnState: c.connStateCallback,
|
||||
}
|
||||
|
||||
go c.client.Listen()
|
||||
return c, nil
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
|
||||
c.cryptoChangedCond.L.Lock()
|
||||
defer c.cryptoChangedCond.L.Unlock()
|
||||
// Dial dials the connection
|
||||
func (c *Client) Dial() error {
|
||||
_, err := quic.DialAddr(c.hostname, c.config)
|
||||
return err
|
||||
}
|
||||
|
||||
if isForwardSecure {
|
||||
c.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
utils.Debugf("is forward secure")
|
||||
} else {
|
||||
func (c *Client) connStateCallback(sess quic.Session, state quic.ConnState) {
|
||||
c.mutex.Lock()
|
||||
if c.session == nil {
|
||||
c.session = sess
|
||||
}
|
||||
switch state {
|
||||
case quic.ConnStateVersionNegotiated:
|
||||
// TODO: handle errors
|
||||
c.versionNegotiateCallback()
|
||||
case quic.ConnStateSecure:
|
||||
c.encryptionLevel = protocol.EncryptionSecure
|
||||
utils.Debugf("is secure")
|
||||
c.cryptoChangedCond.Broadcast()
|
||||
case quic.ConnStateForwardSecure:
|
||||
c.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
utils.Debugf("is forward secure")
|
||||
c.cryptoChangedCond.Broadcast()
|
||||
}
|
||||
c.cryptoChangedCond.Broadcast()
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) versionNegotiateCallback() error {
|
||||
var err error
|
||||
// once the version has been negotiated, open the header stream
|
||||
c.headerStream, err = c.client.OpenStream()
|
||||
c.headerStream, err = c.session.OpenStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -162,7 +166,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
|||
}
|
||||
hdrChan := make(chan *http.Response)
|
||||
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
|
||||
dataStream, err := c.client.OpenStream()
|
||||
dataStream, err := c.session.OpenStream()
|
||||
if err != nil {
|
||||
c.Close(err)
|
||||
return nil, err
|
||||
|
@ -260,7 +264,7 @@ func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
|
|||
|
||||
// Close closes the client
|
||||
func (c *Client) Close(e error) {
|
||||
_ = c.client.Close(e)
|
||||
_ = c.session.Close(e)
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
|
@ -17,85 +18,73 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockQuicClient struct {
|
||||
nextStream protocol.StreamID
|
||||
streams map[protocol.StreamID]*mockStream
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil }
|
||||
func (m *mockQuicClient) Listen() error { panic("not implemented") }
|
||||
func (m *mockQuicClient) OpenStream() (quic.Stream, error) {
|
||||
id := m.nextStream
|
||||
ms := &mockStream{id: id}
|
||||
m.streams[id] = ms
|
||||
m.nextStream += 2
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
func newMockQuicClient() *mockQuicClient {
|
||||
return &mockQuicClient{
|
||||
streams: make(map[protocol.StreamID]*mockStream),
|
||||
nextStream: 5,
|
||||
}
|
||||
}
|
||||
|
||||
var _ quicClient = &mockQuicClient{}
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
client *Client
|
||||
qClient *mockQuicClient
|
||||
session *mockSession
|
||||
headerStream *mockStream
|
||||
quicTransport *QuicRoundTripper
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
quicTransport = &QuicRoundTripper{}
|
||||
hostname := "quic.clemente.io:1337"
|
||||
client, err = NewClient(quicTransport, nil, hostname)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client = NewClient(quicTransport, nil, hostname)
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
qClient = newMockQuicClient()
|
||||
client.client = qClient
|
||||
session = &mockSession{}
|
||||
client.session = session
|
||||
|
||||
headerStream = &mockStream{}
|
||||
qClient.streams[3] = headerStream
|
||||
headerStream = &mockStream{id: 3}
|
||||
client.headerStream = headerStream
|
||||
client.requestWriter = newRequestWriter(headerStream)
|
||||
})
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
var err error
|
||||
client, err = NewClient(quicTransport, nil, "quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client = NewClient(quicTransport, nil, "quic.clemente.io")
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||
})
|
||||
|
||||
It("dials", func() {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client = NewClient(quicTransport, nil, udpConn.LocalAddr().String())
|
||||
go client.Dial()
|
||||
data := make([]byte, 100)
|
||||
_, err = udpConn.Read(data)
|
||||
hdr, err := quic.ParsePublicHeader(bytes.NewReader(data), protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hdr.VersionFlag).To(BeTrue())
|
||||
Expect(hdr.ConnectionID).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("saves the session when the ConnState callback is called", func() {
|
||||
client.session = nil // unset the session set in BeforeEach
|
||||
client.config.ConnState(session, quic.ConnStateForwardSecure)
|
||||
Expect(client.session).To(Equal(session))
|
||||
})
|
||||
|
||||
It("opens the header stream only after the version has been negotiated", func() {
|
||||
// delete the headerStream openend in the BeforeEach
|
||||
client.headerStream = nil
|
||||
delete(qClient.streams, 3)
|
||||
qClient.nextStream = 3
|
||||
session.streamToOpen = headerStream
|
||||
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
|
||||
// now start the actual test
|
||||
err := client.versionNegotiateCallback()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client.config.ConnState(session, quic.ConnStateVersionNegotiated)
|
||||
Expect(client.headerStream).ToNot(BeNil())
|
||||
Expect(client.headerStream.StreamID()).To(Equal(protocol.StreamID(3)))
|
||||
})
|
||||
|
||||
It("sets the correct crypto level", func() {
|
||||
Expect(client.encryptionLevel).To(Equal(protocol.Unencrypted))
|
||||
client.cryptoChangeCallback(false)
|
||||
client.config.ConnState(session, quic.ConnStateSecure)
|
||||
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionSecure))
|
||||
client.cryptoChangeCallback(true)
|
||||
client.config.ConnState(session, quic.ConnStateForwardSecure)
|
||||
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
|
||||
})
|
||||
|
||||
Context("Doing requests", func() {
|
||||
var request *http.Request
|
||||
var dataStream *mockStream
|
||||
|
||||
getRequest := func(data []byte) *http2.MetaHeadersFrame {
|
||||
r := bytes.NewReader(data)
|
||||
|
@ -122,6 +111,9 @@ var _ = Describe("Client", func() {
|
|||
client.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
dataStream = &mockStream{id: 5}
|
||||
session.streamToOpen = dataStream
|
||||
})
|
||||
|
||||
It("does a request", func(done Done) {
|
||||
|
@ -134,7 +126,6 @@ var _ = Describe("Client", func() {
|
|||
}()
|
||||
|
||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
||||
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
|
||||
Expect(client.responses).To(HaveKey(protocol.StreamID(5)))
|
||||
rsp := &http.Response{
|
||||
Status: "418 I'm a teapot",
|
||||
|
@ -144,7 +135,7 @@ var _ = Describe("Client", func() {
|
|||
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
||||
Expect(doErr).ToNot(HaveOccurred())
|
||||
Expect(doRsp).To(Equal(rsp))
|
||||
Expect(doRsp.Body).ToNot(BeNil())
|
||||
Expect(doRsp.Body).To(Equal(dataStream))
|
||||
Expect(doRsp.ContentLength).To(BeEquivalentTo(-1))
|
||||
Expect(doRsp.Request).To(Equal(request))
|
||||
close(done)
|
||||
|
@ -172,7 +163,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(client.headerErr).To(HaveOccurred())
|
||||
Expect(doErr).To(MatchError(client.headerErr))
|
||||
Expect(doRsp).To(BeNil())
|
||||
Expect(client.client.(*mockQuicClient).closeErr).To(MatchError(client.headerErr))
|
||||
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
|
||||
})
|
||||
|
||||
Context("validating the address", func() {
|
||||
|
@ -192,8 +183,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("adds the port for request URLs without one", func(done Done) {
|
||||
var err error
|
||||
client, err = NewClient(quicTransport, nil, "quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client = NewClient(quicTransport, nil, "quic.clemente.io")
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
|
@ -251,7 +241,6 @@ var _ = Describe("Client", func() {
|
|||
}()
|
||||
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
|
||||
client.responses[5] <- response
|
||||
dataStream := qClient.streams[5]
|
||||
Eventually(func() bool { return doReturned }).Should(BeTrue())
|
||||
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
|
||||
Expect(dataStream.closed).To(BeTrue())
|
||||
|
@ -317,7 +306,7 @@ var _ = Describe("Client", func() {
|
|||
go func() { doRsp, doErr = client.Do(request) }()
|
||||
|
||||
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
|
||||
qClient.streams[5].dataToRead.Write(gzippedData)
|
||||
dataStream.dataToRead.Write(gzippedData)
|
||||
response.Header.Add("Content-Encoding", "gzip")
|
||||
client.responses[5] <- response
|
||||
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
|
||||
|
@ -350,7 +339,7 @@ var _ = Describe("Client", func() {
|
|||
go func() { doRsp, doErr = client.Do(request) }()
|
||||
|
||||
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
|
||||
qClient.streams[5].dataToRead.Write([]byte("not gzipped"))
|
||||
dataStream.dataToRead.Write([]byte("not gzipped"))
|
||||
client.responses[5] <- response
|
||||
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
|
||||
Expect(doErr).ToNot(HaveOccurred())
|
||||
|
@ -369,7 +358,7 @@ var _ = Describe("Client", func() {
|
|||
go func() { doRsp, doErr = client.Do(request) }()
|
||||
|
||||
Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil())
|
||||
qClient.streams[5].dataToRead.Write([]byte("gzipped data"))
|
||||
dataStream.dataToRead.Write([]byte("gzipped data"))
|
||||
client.responses[5] <- response
|
||||
Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil())
|
||||
Expect(doErr).ToNot(HaveOccurred())
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
)
|
||||
|
||||
type h2quicClient interface {
|
||||
Dial() error
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
|
@ -92,8 +93,8 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
|
|||
|
||||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
var err error
|
||||
client, err = NewClient(r, r.TLSClientConfig, hostname)
|
||||
client = NewClient(r, r.TLSClientConfig, hostname)
|
||||
err := client.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -11,6 +11,9 @@ import (
|
|||
|
||||
type mockQuicRoundTripper struct{}
|
||||
|
||||
func (m *mockQuicRoundTripper) Dial() error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockQuicRoundTripper) Do(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{Request: req}, nil
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ type mockSession struct {
|
|||
closedWithError error
|
||||
dataStream quic.Stream
|
||||
streamToAccept quic.Stream
|
||||
streamToOpen quic.Stream
|
||||
}
|
||||
|
||||
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
|
||||
|
@ -36,7 +37,7 @@ func (s *mockSession) AcceptStream() (quic.Stream, error) {
|
|||
return s.streamToAccept, nil
|
||||
}
|
||||
func (s *mockSession) OpenStream() (quic.Stream, error) {
|
||||
panic("not implemented")
|
||||
return s.streamToOpen, nil
|
||||
}
|
||||
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
|
||||
panic("not implemented")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue