mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
parent
e9666c6313
commit
268841f0cc
10 changed files with 28 additions and 14 deletions
|
@ -2,6 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
@ -24,6 +25,7 @@ type Client struct {
|
|||
versionNegotiated bool
|
||||
closed uint32 // atomic bool
|
||||
|
||||
tlsConfig *tls.Config
|
||||
cryptoChangeCallback CryptoChangeCallback
|
||||
versionNegotiateCallback VersionNegotiateCallback
|
||||
|
||||
|
@ -40,7 +42,7 @@ var (
|
|||
)
|
||||
|
||||
// NewClient makes a new client
|
||||
func NewClient(host string, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
|
||||
func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -67,6 +69,7 @@ func NewClient(host string, cryptoChangeCallback CryptoChangeCallback, versionNe
|
|||
hostname: hostname,
|
||||
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
|
||||
connectionID: connectionID,
|
||||
tlsConfig: tlsConfig,
|
||||
cryptoChangeCallback: cryptoChangeCallback,
|
||||
versionNegotiateCallback: versionNegotiateCallback,
|
||||
}
|
||||
|
@ -200,7 +203,7 @@ func (c *Client) handlePacket(packet []byte) error {
|
|||
|
||||
func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
||||
var err error
|
||||
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions)
|
||||
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.tlsConfig, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("creates a new client", func() {
|
||||
var err error
|
||||
client, err = NewClient("quic.clemente.io:1337", nil, nil)
|
||||
client, err = NewClient("quic.clemente.io:1337", nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io"))
|
||||
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*Session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -47,7 +48,7 @@ type Client struct {
|
|||
var _ h2quicClient = &Client{}
|
||||
|
||||
// NewClient creates a new client
|
||||
func NewClient(t *QuicRoundTripper, hostname string) (*Client, error) {
|
||||
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
|
||||
c := &Client{
|
||||
t: t,
|
||||
hostname: authorityAddr("https", hostname),
|
||||
|
@ -57,7 +58,7 @@ func NewClient(t *QuicRoundTripper, hostname string) (*Client, error) {
|
|||
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
|
||||
|
||||
var err error
|
||||
c.client, err = quic.NewClient(c.hostname, c.cryptoChangeCallback, c.versionNegotiateCallback)
|
||||
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ var _ = Describe("Client", func() {
|
|||
var err error
|
||||
quicTransport = &QuicRoundTripper{}
|
||||
hostname := "quic.clemente.io:1337"
|
||||
client, err = NewClient(quicTransport, hostname)
|
||||
client, err = NewClient(quicTransport, nil, hostname)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
qClient = newMockQuicClient()
|
||||
|
@ -68,7 +68,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
var err error
|
||||
client, err = NewClient(quicTransport, "quic.clemente.io")
|
||||
client, err = NewClient(quicTransport, nil, "quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||
})
|
||||
|
@ -192,7 +192,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("adds the port for request URLs without one", func(done Done) {
|
||||
var err error
|
||||
client, err = NewClient(quicTransport, "quic.clemente.io")
|
||||
client, err = NewClient(quicTransport, nil, "quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -28,6 +29,10 @@ type QuicRoundTripper struct {
|
|||
// uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
clients map[string]h2quicClient
|
||||
}
|
||||
|
||||
|
@ -88,7 +93,7 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
|
|||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
var err error
|
||||
client, err = NewClient(r, hostname)
|
||||
client, err = NewClient(r, r.TLSClientConfig, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -57,7 +57,8 @@ var _ = Describe("RoundTripper", func() {
|
|||
It("reuses existing clients", func() {
|
||||
rt.clients = make(map[string]h2quicClient)
|
||||
rt.clients["www.example.org:443"] = &mockQuicRoundTripper{}
|
||||
rsp, _ := rt.RoundTrip(req1)
|
||||
rsp, err := rt.RoundTrip(req1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Request).To(Equal(req1))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
})
|
||||
|
|
|
@ -3,6 +3,7 @@ package handshake
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -64,6 +65,7 @@ func NewCryptoSetupClient(
|
|||
connID protocol.ConnectionID,
|
||||
version protocol.VersionNumber,
|
||||
cryptoStream utils.Stream,
|
||||
tlsConfig *tls.Config,
|
||||
connectionParameters ConnectionParametersManager,
|
||||
aeadChanged chan struct{},
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
|
@ -73,7 +75,7 @@ func NewCryptoSetupClient(
|
|||
connID: connID,
|
||||
version: version,
|
||||
cryptoStream: cryptoStream,
|
||||
certManager: crypto.NewCertManager(nil),
|
||||
certManager: crypto.NewCertManager(tlsConfig),
|
||||
connectionParameters: connectionParameters,
|
||||
keyDerivation: crypto.DeriveKeysAESGCM,
|
||||
aeadChanged: aeadChanged,
|
||||
|
|
|
@ -122,7 +122,7 @@ var _ = Describe("Crypto setup", func() {
|
|||
stream = &mockStream{}
|
||||
certManager = &mockCertManager{}
|
||||
version := protocol.Version36
|
||||
csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan struct{}, 1), nil)
|
||||
csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, nil, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan struct{}, 1), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cs = csInt.(*cryptoSetupClient)
|
||||
cs.certManager = certManager
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -127,7 +128,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
|||
return session, err
|
||||
}
|
||||
|
||||
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) {
|
||||
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) {
|
||||
session := &Session{
|
||||
conn: &udpConn{conn: conn, currentAddr: addr},
|
||||
connectionID: connectionID,
|
||||
|
@ -145,7 +146,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p
|
|||
|
||||
cryptoStream, _ := session.OpenStream(1)
|
||||
var err error
|
||||
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged, negotiatedVersions)
|
||||
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, session.connectionParameters, session.aeadChanged, negotiatedVersions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -155,6 +155,7 @@ var _ = Describe("Session", func() {
|
|||
"hostname",
|
||||
protocol.Version35,
|
||||
0,
|
||||
nil,
|
||||
func(*Session, utils.Stream) { streamCallbackCalled = true },
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(isForwardSecure bool) {},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue