implement closing the quic client with an error

This commit is contained in:
Marten Seemann 2016-12-17 10:11:31 +07:00
parent c42262c2b3
commit 08c267431b
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
5 changed files with 16 additions and 11 deletions

View file

@ -110,8 +110,8 @@ func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) {
} }
// Close closes the connection // Close closes the connection
func (c *Client) Close() error { func (c *Client) Close(e error) error {
_ = c.session.Close(nil) _ = c.session.Close(e)
return c.conn.Close() return c.conn.Close()
} }

View file

@ -3,7 +3,9 @@ package quic
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"net" "net"
"runtime"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
@ -59,6 +61,7 @@ var _ = Describe("Client", func() {
}) })
It("properly closes the client", func(done Done) { It("properly closes the client", func(done Done) {
numGoRoutines := runtime.NumGoroutine()
startUDPConn() startUDPConn()
var stoppedListening bool var stoppedListening bool
go func() { go func() {
@ -67,11 +70,13 @@ var _ = Describe("Client", func() {
stoppedListening = true stoppedListening = true
}() }()
err := client.Close() testErr := errors.New("test error")
err := client.Close(testErr)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(session.closed).Should(BeTrue()) Eventually(session.closed).Should(BeTrue())
Expect(session.closeReason).To(BeNil()) Expect(session.closeReason).To(MatchError(testErr))
Eventually(func() bool { return stoppedListening }).Should(BeTrue()) Eventually(func() bool { return stoppedListening }).Should(BeTrue())
Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines))
close(done) close(done)
}) })
@ -85,7 +90,7 @@ var _ = Describe("Client", func() {
Expect(client.session.(*Session).connectionID).To(Equal(client.connectionID)) Expect(client.session.(*Session).connectionID).To(Equal(client.connectionID))
Expect(client.session.(*Session).version).To(Equal(client.version)) Expect(client.session.(*Session).version).To(Equal(client.version))
err = client.Close() err = client.Close(nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -130,7 +135,7 @@ var _ = Describe("Client", func() {
Expect(session.closed).To(BeFalse()) Expect(session.closed).To(BeFalse())
Eventually(func() bool { return stoppedListening }).Should(BeFalse()) Eventually(func() bool { return stoppedListening }).Should(BeFalse())
err = client.Close() err = client.Close(nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
close(done) close(done)
}) })
@ -202,7 +207,7 @@ var _ = Describe("Client", func() {
// it didn't pass the version negoation packet to the session (since it has no payload) // it didn't pass the version negoation packet to the session (since it has no payload)
Expect(session.packetCount).To(BeZero()) Expect(session.packetCount).To(BeZero())
err = client.Close() err = client.Close(nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })

View file

@ -16,7 +16,7 @@ func main() {
} }
err = client.Listen() err = client.Listen()
defer client.Close() defer client.Close(nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -16,7 +16,7 @@ import (
type quicClient interface { type quicClient interface {
OpenStream(protocol.StreamID) (utils.Stream, error) OpenStream(protocol.StreamID) (utils.Stream, error)
Close() error Close(error) error
Listen() error Listen() error
} }

View file

@ -14,8 +14,8 @@ type mockQuicClient struct {
streams map[protocol.StreamID]*mockStream streams map[protocol.StreamID]*mockStream
} }
func (m *mockQuicClient) Close() error { panic("not implemented") } func (m *mockQuicClient) Close(error) error { panic("not implemented") }
func (m *mockQuicClient) Listen() error { panic("not implemented") } func (m *mockQuicClient) Listen() error { panic("not implemented") }
func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) { func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) {
_, ok := m.streams[id] _, ok := m.streams[id]
if ok { if ok {