copy the GetLogWriter callback when populating a quic.Config

This commit is contained in:
Marten Seemann 2020-02-27 18:06:25 +07:00
parent d31dcdaa7b
commit 4109df32ff
3 changed files with 90 additions and 37 deletions

View file

@ -248,16 +248,6 @@ func newClient(
return c, nil
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
config = populateConfig(config)
if config.ConnectionIDLength == 0 && !createdPacketConn {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
}
return config
}
func (c *client) dial(ctx context.Context, qlogger qlog.Tracer) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
if qlogger != nil {

View file

@ -21,6 +21,16 @@ func populateServerConfig(config *Config) *Config {
return config
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
config = populateConfig(config)
if config.ConnectionIDLength == 0 && !createdPacketConn {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
}
return config
}
func populateConfig(config *Config) *Config {
if config == nil {
config = &Config{}
@ -72,5 +82,6 @@ func populateConfig(config *Config) *Config {
StatelessResetKey: config.StatelessResetKey,
TokenStore: config.TokenStore,
QuicTracer: config.QuicTracer,
GetLogWriter: config.GetLogWriter,
}
}

View file

@ -7,6 +7,7 @@ import (
"reflect"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quictrace"
. "github.com/onsi/ginkgo"
@ -14,22 +15,9 @@ import (
)
var _ = Describe("Config", func() {
It("clones function fields", func() {
var calledAcceptToken, calledGetLogWriter bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil },
}
c2 := c1.Clone()
c2.AcceptToken(&net.UDPAddr{}, &Token{})
c2.GetLogWriter([]byte{1, 2, 3})
Expect(calledAcceptToken).To(BeTrue())
Expect(calledGetLogWriter).To(BeTrue())
})
It("clones non-function fields", func() {
c1 := &Config{}
v := reflect.ValueOf(c1).Elem()
configWithNonZeroNonFunctionFields := func() *Config {
c := &Config{}
v := reflect.ValueOf(c).Elem()
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
@ -70,20 +58,84 @@ var _ = Describe("Config", func() {
Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
}
}
return c
}
Context("cloning", func() {
It("clones function fields", func() {
var calledAcceptToken, calledGetLogWriter bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil },
}
c2 := c1.Clone()
c2.AcceptToken(&net.UDPAddr{}, &Token{})
c2.GetLogWriter([]byte{1, 2, 3})
Expect(calledAcceptToken).To(BeTrue())
Expect(calledGetLogWriter).To(BeTrue())
})
Expect(c1.Clone()).To(Equal(c1))
It("clones non-function fields", func() {
c := configWithNonZeroNonFunctionFields()
Expect(c.Clone()).To(Equal(c))
})
It("returns a copy", func() {
c1 := &Config{
MaxIncomingStreams: 100,
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
}
c2 := c1.Clone()
c2.MaxIncomingStreams = 200
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
})
})
It("returns a copy", func() {
c1 := &Config{
MaxIncomingStreams: 100,
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
}
c2 := c1.Clone()
c2.MaxIncomingStreams = 200
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
Context("populating", func() {
It("populates function fields", func() {
var calledAcceptToken, calledGetLogWriter bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil },
}
c2 := populateConfig(c1)
c2.AcceptToken(&net.UDPAddr{}, &Token{})
c2.GetLogWriter([]byte{1, 2, 3})
Expect(calledAcceptToken).To(BeTrue())
Expect(calledGetLogWriter).To(BeTrue())
})
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
It("copies non-function fields", func() {
c := configWithNonZeroNonFunctionFields()
Expect(populateConfig(c)).To(Equal(c))
})
It("populates empty fields with default values", func() {
c := populateConfig(&Config{})
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout))
Expect(c.MaxReceiveStreamFlowControlWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow))
Expect(c.MaxReceiveConnectionFlowControlWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow))
Expect(c.MaxIncomingStreams).To(Equal(protocol.DefaultMaxIncomingStreams))
Expect(c.MaxIncomingUniStreams).To(Equal(protocol.DefaultMaxIncomingUniStreams))
})
It("populates empty fields with default values, for the server", func() {
c := populateServerConfig(&Config{})
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
Expect(c.AcceptToken).ToNot(BeNil())
})
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
c := populateClientConfig(&Config{}, false)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})
It("doesn't set a default connection ID length if we created the conn, for the client", func() {
c := populateClientConfig(&Config{}, true)
Expect(c.ConnectionIDLength).To(BeZero())
})
})
})