correctly handle connection errors in the client

This commit is contained in:
Marten Seemann 2017-02-22 12:41:07 +07:00
parent 96edca5219
commit 8247454b0f
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
3 changed files with 22 additions and 13 deletions

View file

@ -83,7 +83,6 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version) utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version)
// TODO: handle errors
go c.Listen() go c.Listen()
c.mutex.Lock() c.mutex.Lock()
@ -111,17 +110,17 @@ func DialAddr(hostname string, config *Config) (Session, error) {
} }
// Listen listens // Listen listens
func (c *client) Listen() error { func (c *client) Listen() {
for { for {
data := getPacketBuffer() data := getPacketBuffer()
data = data[:protocol.MaxPacketSize] data = data[:protocol.MaxPacketSize]
n, addr, err := c.conn.Read(data) n, addr, err := c.conn.Read(data)
if err != nil { if err != nil {
if strings.HasSuffix(err.Error(), "use of closed network connection") { if !strings.HasSuffix(err.Error(), "use of closed network connection") {
return nil c.session.Close(err)
} }
return err return
} }
data = data[:n] data = data[:n]
@ -129,7 +128,7 @@ func (c *client) Listen() error {
if err != nil { if err != nil {
utils.Errorf("error handling packet: %s", err.Error()) utils.Errorf("error handling packet: %s", err.Error())
c.session.Close(err) c.session.Close(err)
return err return
} }
} }
} }

View file

@ -74,9 +74,7 @@ var _ = Describe("Client", func() {
var stoppedListening bool var stoppedListening bool
go func() { go func() {
defer GinkgoRecover() cl.Listen()
err := cl.Listen()
Expect(err).ToNot(HaveOccurred())
stoppedListening = true stoppedListening = true
}() }()
@ -131,7 +129,7 @@ var _ = Describe("Client", func() {
Expect(sess.packetCount).To(BeZero()) Expect(sess.packetCount).To(BeZero())
var stoppedListening bool var stoppedListening bool
go func() { go func() {
_ = cl.Listen() cl.Listen()
// it should continue listening when receiving valid packets // it should continue listening when receiving valid packets
stoppedListening = true stoppedListening = true
}() }()
@ -142,11 +140,19 @@ var _ = Describe("Client", func() {
}) })
It("closes the session when encountering an error while handling a packet", func() { It("closes the session when encountering an error while handling a packet", func() {
Expect(sess.closeReason).ToNot(HaveOccurred())
packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100) packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100)
listenErr := cl.Listen() cl.Listen()
Expect(listenErr).To(HaveOccurred())
Expect(sess.closed).To(BeTrue()) Expect(sess.closed).To(BeTrue())
Expect(sess.closeReason).To(MatchError(listenErr)) Expect(sess.closeReason).To(HaveOccurred())
})
It("closes the session when encountering an error while reading from the connection", func() {
testErr := errors.New("test error")
packetConn.readErr = testErr
cl.Listen()
Expect(sess.closed).To(BeTrue())
Expect(sess.closeReason).To(MatchError(testErr))
}) })
}) })

View file

@ -14,12 +14,16 @@ type mockPacketConn struct {
addr net.Addr addr net.Addr
dataToRead []byte dataToRead []byte
dataReadFrom net.Addr dataReadFrom net.Addr
readErr error
dataWritten bytes.Buffer dataWritten bytes.Buffer
dataWrittenTo net.Addr dataWrittenTo net.Addr
closed bool closed bool
} }
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
if c.readErr != nil {
return 0, nil, c.readErr
}
if c.dataToRead == nil { // block if there's no data if c.dataToRead == nil { // block if there's no data
time.Sleep(time.Hour) time.Sleep(time.Hour)
return 0, nil, io.EOF return 0, nil, io.EOF