mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
copy the GetLogWriter callback when populating a quic.Config
This commit is contained in:
parent
d31dcdaa7b
commit
4109df32ff
3 changed files with 90 additions and 37 deletions
10
client.go
10
client.go
|
@ -248,16 +248,6 @@ func newClient(
|
|||
return c, nil
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
config = populateConfig(config)
|
||||
if config.ConnectionIDLength == 0 && !createdPacketConn {
|
||||
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context, qlogger qlog.Tracer) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
if qlogger != nil {
|
||||
|
|
11
config.go
11
config.go
|
@ -21,6 +21,16 @@ func populateServerConfig(config *Config) *Config {
|
|||
return config
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
config = populateConfig(config)
|
||||
if config.ConnectionIDLength == 0 && !createdPacketConn {
|
||||
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func populateConfig(config *Config) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
|
@ -72,5 +82,6 @@ func populateConfig(config *Config) *Config {
|
|||
StatelessResetKey: config.StatelessResetKey,
|
||||
TokenStore: config.TokenStore,
|
||||
QuicTracer: config.QuicTracer,
|
||||
GetLogWriter: config.GetLogWriter,
|
||||
}
|
||||
}
|
||||
|
|
106
config_test.go
106
config_test.go
|
@ -7,6 +7,7 @@ import (
|
|||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/quictrace"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
|
@ -14,22 +15,9 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Config", func() {
|
||||
It("clones function fields", func() {
|
||||
var calledAcceptToken, calledGetLogWriter bool
|
||||
c1 := &Config{
|
||||
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
|
||||
GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil },
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.AcceptToken(&net.UDPAddr{}, &Token{})
|
||||
c2.GetLogWriter([]byte{1, 2, 3})
|
||||
Expect(calledAcceptToken).To(BeTrue())
|
||||
Expect(calledGetLogWriter).To(BeTrue())
|
||||
})
|
||||
|
||||
It("clones non-function fields", func() {
|
||||
c1 := &Config{}
|
||||
v := reflect.ValueOf(c1).Elem()
|
||||
configWithNonZeroNonFunctionFields := func() *Config {
|
||||
c := &Config{}
|
||||
v := reflect.ValueOf(c).Elem()
|
||||
|
||||
typ := v.Type()
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
|
@ -70,20 +58,84 @@ var _ = Describe("Config", func() {
|
|||
Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
Context("cloning", func() {
|
||||
It("clones function fields", func() {
|
||||
var calledAcceptToken, calledGetLogWriter bool
|
||||
c1 := &Config{
|
||||
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
|
||||
GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil },
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.AcceptToken(&net.UDPAddr{}, &Token{})
|
||||
c2.GetLogWriter([]byte{1, 2, 3})
|
||||
Expect(calledAcceptToken).To(BeTrue())
|
||||
Expect(calledGetLogWriter).To(BeTrue())
|
||||
})
|
||||
|
||||
Expect(c1.Clone()).To(Equal(c1))
|
||||
It("clones non-function fields", func() {
|
||||
c := configWithNonZeroNonFunctionFields()
|
||||
Expect(c.Clone()).To(Equal(c))
|
||||
})
|
||||
|
||||
It("returns a copy", func() {
|
||||
c1 := &Config{
|
||||
MaxIncomingStreams: 100,
|
||||
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.MaxIncomingStreams = 200
|
||||
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
||||
|
||||
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
|
||||
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
It("returns a copy", func() {
|
||||
c1 := &Config{
|
||||
MaxIncomingStreams: 100,
|
||||
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.MaxIncomingStreams = 200
|
||||
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
|
||||
Context("populating", func() {
|
||||
It("populates function fields", func() {
|
||||
var calledAcceptToken, calledGetLogWriter bool
|
||||
c1 := &Config{
|
||||
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
|
||||
GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil },
|
||||
}
|
||||
c2 := populateConfig(c1)
|
||||
c2.AcceptToken(&net.UDPAddr{}, &Token{})
|
||||
c2.GetLogWriter([]byte{1, 2, 3})
|
||||
Expect(calledAcceptToken).To(BeTrue())
|
||||
Expect(calledGetLogWriter).To(BeTrue())
|
||||
})
|
||||
|
||||
Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
|
||||
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
|
||||
It("copies non-function fields", func() {
|
||||
c := configWithNonZeroNonFunctionFields()
|
||||
Expect(populateConfig(c)).To(Equal(c))
|
||||
})
|
||||
|
||||
It("populates empty fields with default values", func() {
|
||||
c := populateConfig(&Config{})
|
||||
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(c.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout))
|
||||
Expect(c.MaxReceiveStreamFlowControlWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow))
|
||||
Expect(c.MaxReceiveConnectionFlowControlWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow))
|
||||
Expect(c.MaxIncomingStreams).To(Equal(protocol.DefaultMaxIncomingStreams))
|
||||
Expect(c.MaxIncomingUniStreams).To(Equal(protocol.DefaultMaxIncomingUniStreams))
|
||||
})
|
||||
|
||||
It("populates empty fields with default values, for the server", func() {
|
||||
c := populateServerConfig(&Config{})
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
Expect(c.AcceptToken).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
|
||||
c := populateClientConfig(&Config{}, false)
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
})
|
||||
|
||||
It("doesn't set a default connection ID length if we created the conn, for the client", func() {
|
||||
c := populateClientConfig(&Config{}, true)
|
||||
Expect(c.ConnectionIDLength).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue