From 4109df32fffe3704283a46f376e50027d791ce6c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 27 Feb 2020 18:06:25 +0700 Subject: [PATCH] copy the GetLogWriter callback when populating a quic.Config --- client.go | 10 ----- config.go | 11 +++++ config_test.go | 106 ++++++++++++++++++++++++++++++++++++------------- 3 files changed, 90 insertions(+), 37 deletions(-) diff --git a/client.go b/client.go index c1bc278d..0f9d4ac7 100644 --- a/client.go +++ b/client.go @@ -248,16 +248,6 @@ func newClient( return c, nil } -// populateClientConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil -func populateClientConfig(config *Config, createdPacketConn bool) *Config { - config = populateConfig(config) - if config.ConnectionIDLength == 0 && !createdPacketConn { - config.ConnectionIDLength = protocol.DefaultConnectionIDLength - } - return config -} - func (c *client) dial(ctx context.Context, qlogger qlog.Tracer) error { c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) if qlogger != nil { diff --git a/config.go b/config.go index 5d05f334..729b7779 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,16 @@ func populateServerConfig(config *Config) *Config { return config } +// populateClientConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil +func populateClientConfig(config *Config, createdPacketConn bool) *Config { + config = populateConfig(config) + if config.ConnectionIDLength == 0 && !createdPacketConn { + config.ConnectionIDLength = protocol.DefaultConnectionIDLength + } + return config +} + func populateConfig(config *Config) *Config { if config == nil { config = &Config{} @@ -72,5 +82,6 @@ func populateConfig(config *Config) *Config { StatelessResetKey: config.StatelessResetKey, TokenStore: config.TokenStore, QuicTracer: config.QuicTracer, + GetLogWriter: config.GetLogWriter, } } diff --git a/config_test.go b/config_test.go index e1c86dbf..10a4423a 100644 --- a/config_test.go +++ b/config_test.go @@ -7,6 +7,7 @@ import ( "reflect" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/quictrace" . "github.com/onsi/ginkgo" @@ -14,22 +15,9 @@ import ( ) var _ = Describe("Config", func() { - It("clones function fields", func() { - var calledAcceptToken, calledGetLogWriter bool - c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, - GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil }, - } - c2 := c1.Clone() - c2.AcceptToken(&net.UDPAddr{}, &Token{}) - c2.GetLogWriter([]byte{1, 2, 3}) - Expect(calledAcceptToken).To(BeTrue()) - Expect(calledGetLogWriter).To(BeTrue()) - }) - - It("clones non-function fields", func() { - c1 := &Config{} - v := reflect.ValueOf(c1).Elem() + configWithNonZeroNonFunctionFields := func() *Config { + c := &Config{} + v := reflect.ValueOf(c).Elem() typ := v.Type() for i := 0; i < typ.NumField(); i++ { @@ -70,20 +58,84 @@ var _ = Describe("Config", func() { Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) } } + return c + } + Context("cloning", func() { + It("clones function fields", func() { + var calledAcceptToken, calledGetLogWriter bool + c1 := &Config{ + AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, + GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil }, + } + c2 := c1.Clone() + c2.AcceptToken(&net.UDPAddr{}, &Token{}) + c2.GetLogWriter([]byte{1, 2, 3}) + Expect(calledAcceptToken).To(BeTrue()) + Expect(calledGetLogWriter).To(BeTrue()) + }) - Expect(c1.Clone()).To(Equal(c1)) + It("clones non-function fields", func() { + c := configWithNonZeroNonFunctionFields() + Expect(c.Clone()).To(Equal(c)) + }) + + It("returns a copy", func() { + c1 := &Config{ + MaxIncomingStreams: 100, + AcceptToken: func(_ net.Addr, _ *Token) bool { return true }, + } + c2 := c1.Clone() + c2.MaxIncomingStreams = 200 + c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + + Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) + Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue()) + }) }) - It("returns a copy", func() { - c1 := &Config{ - MaxIncomingStreams: 100, - AcceptToken: func(_ net.Addr, _ *Token) bool { return true }, - } - c2 := c1.Clone() - c2.MaxIncomingStreams = 200 - c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + Context("populating", func() { + It("populates function fields", func() { + var calledAcceptToken, calledGetLogWriter bool + c1 := &Config{ + AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, + GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil }, + } + c2 := populateConfig(c1) + c2.AcceptToken(&net.UDPAddr{}, &Token{}) + c2.GetLogWriter([]byte{1, 2, 3}) + Expect(calledAcceptToken).To(BeTrue()) + Expect(calledGetLogWriter).To(BeTrue()) + }) - Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) - Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue()) + It("copies non-function fields", func() { + c := configWithNonZeroNonFunctionFields() + Expect(populateConfig(c)).To(Equal(c)) + }) + + It("populates empty fields with default values", func() { + c := populateConfig(&Config{}) + Expect(c.Versions).To(Equal(protocol.SupportedVersions)) + Expect(c.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout)) + Expect(c.MaxReceiveStreamFlowControlWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow)) + Expect(c.MaxReceiveConnectionFlowControlWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) + Expect(c.MaxIncomingStreams).To(Equal(protocol.DefaultMaxIncomingStreams)) + Expect(c.MaxIncomingUniStreams).To(Equal(protocol.DefaultMaxIncomingUniStreams)) + }) + + It("populates empty fields with default values, for the server", func() { + c := populateServerConfig(&Config{}) + Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) + Expect(c.AcceptToken).ToNot(BeNil()) + }) + + It("sets a default connection ID length if we didn't create the conn, for the client", func() { + c := populateClientConfig(&Config{}, false) + Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) + }) + + It("doesn't set a default connection ID length if we created the conn, for the client", func() { + c := populateClientConfig(&Config{}, true) + Expect(c.ConnectionIDLength).To(BeZero()) + }) }) })