use a smaller packetHandler interface

The packetHandler interface just needs two methods: one for handling
packets, and one for closing.
This commit is contained in:
Marten Seemann 2018-05-21 09:19:16 +08:00
parent ef34d9e85f
commit b3fd768a61
11 changed files with 147 additions and 142 deletions

View file

@ -28,7 +28,7 @@ var _ = Describe("Client", func() {
addr net.Addr
connID protocol.ConnectionID
originalClientSessConstructor func(connection, sessionRunner, string, protocol.VersionNumber, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (packetHandler, error)
originalClientSessConstructor func(connection, sessionRunner, string, protocol.VersionNumber, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (quicSession, error)
)
// generate a packet sent by the server that accepts the QUIC version suggested by the client
@ -49,7 +49,7 @@ var _ = Describe("Client", func() {
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
originalClientSessConstructor = newClientSession
Eventually(areSessionsRunning).Should(BeFalse())
// sess = NewMockPacketHandler(mockCtrl)
// sess = NewMockQuicSession(mockCtrl)
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = newMockPacketConn()
packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
@ -104,9 +104,9 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
remoteAddrChan <- conn.RemoteAddr().String()
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
return sess, nil
}
@ -128,9 +128,9 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
hostnameChan <- h
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
return sess, nil
}
@ -159,8 +159,8 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
sess := NewMockPacketHandler(mockCtrl)
) (quicSession, error) {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Do(func() { close(run) })
sess.EXPECT().handlePacket(gomock.Any())
runner.onHandshakeComplete(sess)
@ -187,8 +187,8 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
sess := NewMockPacketHandler(mockCtrl)
) (quicSession, error) {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) { close(handledPacket) })
sess.EXPECT().run().Return(testErr)
return sess, nil
@ -202,7 +202,7 @@ var _ = Describe("Client", func() {
It("closes the session when the context is canceledd", func() {
sessionRunning := make(chan struct{})
defer close(sessionRunning)
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Do(func() {
<-sessionRunning
})
@ -217,7 +217,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
return sess, nil
}
ctx, cancel := context.WithCancel(context.Background())
@ -300,7 +300,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
return nil, testErr
}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil)
@ -328,14 +328,14 @@ var _ = Describe("Client", func() {
paramsChan <-chan handshake.TransportParameters,
_ protocol.PacketNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
cconn = connP
hostname = hostnameP
version = versionP
conf = configP
close(c)
// TODO: check connection IDs?
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
return sess, nil
}
@ -374,9 +374,9 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(testErr)
return sess, nil
}
@ -385,7 +385,7 @@ var _ = Describe("Client", func() {
})
It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() {
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
ph := wire.Header{
@ -406,13 +406,13 @@ var _ = Describe("Client", func() {
version1 := protocol.Version39
version2 := protocol.Version39 + 1
Expect(version2.UsesTLS()).To(BeFalse())
sess1 := NewMockPacketHandler(mockCtrl)
sess1 := NewMockQuicSession(mockCtrl)
run1 := make(chan struct{})
sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion)
sess1.EXPECT().Close(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
sess2 := NewMockPacketHandler(mockCtrl)
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessionChan := make(chan *MockPacketHandler, 2)
sessionChan := make(chan *MockQuicSession, 2)
sessionChan <- sess1
sessionChan <- sess2
newClientSession = func(
@ -426,7 +426,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
return <-sessionChan, nil
}
@ -450,13 +450,13 @@ var _ = Describe("Client", func() {
version3 := protocol.Version39 + 2
Expect(version2.UsesTLS()).To(BeFalse())
Expect(version3.UsesTLS()).To(BeFalse())
sess1 := NewMockPacketHandler(mockCtrl)
sess1 := NewMockQuicSession(mockCtrl)
run1 := make(chan struct{})
sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion)
sess1.EXPECT().Close(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
sess2 := NewMockPacketHandler(mockCtrl)
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessionChan := make(chan *MockPacketHandler, 2)
sessionChan := make(chan *MockQuicSession, 2)
sessionChan <- sess1
sessionChan <- sess2
newClientSession = func(
@ -470,7 +470,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
return <-sessionChan, nil
}
@ -492,7 +492,7 @@ var _ = Describe("Client", func() {
})
It("errors if no matching version is found", func() {
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().Close(gomock.Any())
cl.session = sess
cl.config = &Config{Versions: protocol.SupportedVersions}
@ -501,7 +501,7 @@ var _ = Describe("Client", func() {
})
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().Close(gomock.Any())
cl.session = sess
v := protocol.VersionNumber(1234)
@ -512,7 +512,7 @@ var _ = Describe("Client", func() {
})
It("changes to the version preferred by the quic.Config", func() {
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().Close(errCloseSessionForNewVersion)
cl.session = sess
config := &Config{Versions: []protocol.VersionNumber{1234, 4321}}
@ -532,14 +532,14 @@ var _ = Describe("Client", func() {
})
It("ignores packets with an invalid public header", func() {
cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any handlePacket calls
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
err := cl.handlePacket(addr, []byte("invalid packet"))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error parsing packet from"))
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any handlePacket calls
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
@ -555,7 +555,7 @@ var _ = Describe("Client", func() {
})
It("cuts packets at the payload length", func() {
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) {
Expect(packet.data).To(HaveLen(123))
})
@ -592,7 +592,7 @@ var _ = Describe("Client", func() {
})
It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any handlePacket calls
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
cl.config = &Config{RequestConnectionIDOmission: false}
buf := &bytes.Buffer{}
err := (&wire.Header{
@ -608,7 +608,7 @@ var _ = Describe("Client", func() {
})
It("ignores packets with the wrong destination connection ID", func() {
cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any handlePacket calls
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
buf := &bytes.Buffer{}
cl.version = versionIETFFrames
cl.config = &Config{RequestConnectionIDOmission: false}
@ -644,13 +644,13 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
cconn = connP
hostname = hostnameP
version = versionP
conf = configP
close(c)
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
return sess, nil
}
@ -666,11 +666,11 @@ var _ = Describe("Client", func() {
It("creates a new session when the server performs a retry", func() {
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
cl.config = config
sess1 := NewMockPacketHandler(mockCtrl)
sess1 := NewMockQuicSession(mockCtrl)
sess1.EXPECT().run().Return(handshake.ErrCloseSessionForRetry)
sess2 := NewMockPacketHandler(mockCtrl)
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessions := []*MockPacketHandler{sess1, sess2}
sessions := []*MockQuicSession{sess1, sess2}
newTLSClientSession = func(
connP connection,
_ sessionRunner,
@ -683,7 +683,7 @@ var _ = Describe("Client", func() {
paramsChan <-chan handshake.TransportParameters,
_ protocol.PacketNumber,
_ utils.Logger,
) (packetHandler, error) {
) (quicSession, error) {
sess := sessions[0]
sessions = sessions[1:]
return sess, nil
@ -749,7 +749,7 @@ var _ = Describe("Client", func() {
Context("handling packets", func() {
It("handles packets", func() {
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
ph := wire.Header{
@ -780,7 +780,7 @@ var _ = Describe("Client", func() {
It("closes the session when encountering an error while reading from the connection", func() {
testErr := errors.New("test error")
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().Close(testErr)
cl.session = sess
packetConn.readErr = testErr
@ -790,7 +790,7 @@ var _ = Describe("Client", func() {
Context("Public Reset handling", func() {
It("closes the session when receiving a Public Reset", func() {
sess := NewMockPacketHandler(mockCtrl)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().closeRemote(gomock.Any()).Do(func(err error) {
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset))
})
@ -800,14 +800,14 @@ var _ = Describe("Client", func() {
})
It("ignores Public Resets from the wrong remote address", func() {
cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any calls
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls
spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678}
err := cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.destConnID, 1, 0))
Expect(err).To(MatchError("Received a spoofed Public Reset"))
})
It("ignores unparseable Public Resets", func() {
cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any calls
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls
pr := wire.WritePublicReset(cl.destConnID, 1, 0)
err := cl.handlePacket(addr, pr[:len(pr)-5])
Expect(err).To(HaveOccurred())