diff --git a/core/client/reconnect.go b/core/client/reconnect.go index 137285f..05d60b3 100644 --- a/core/client/reconnect.go +++ b/core/client/reconnect.go @@ -56,53 +56,56 @@ func (rc *reconnectableClientImpl) reconnect() error { } } -func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { +// clientDo calls f with the current client. +// If the client is nil, it will first reconnect. +// It will also detect if the client is closed, and if so, +// set it to nil for reconnect next time. +func (rc *reconnectableClientImpl) clientDo(f func(Client) (interface{}, error)) (interface{}, error) { rc.m.Lock() - defer rc.m.Unlock() if rc.closed { + rc.m.Unlock() return nil, coreErrs.ClosedError{} } if rc.client == nil { // No active connection, connect first if err := rc.reconnect(); err != nil { + rc.m.Unlock() return nil, err } } - conn, err := rc.client.TCP(addr) + client := rc.client + rc.m.Unlock() + + ret, err := f(client) if _, ok := err.(coreErrs.ClosedError); ok { - // Connection closed, reconnect - if err := rc.reconnect(); err != nil { - return nil, err + // Connection closed, set client to nil for reconnect next time + rc.m.Lock() + if rc.client == client { + // This check is in case the client is already changed by another goroutine + rc.client = nil } - return rc.client.TCP(addr) + rc.m.Unlock() + } + return ret, err +} + +func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { + if c, err := rc.clientDo(func(client Client) (interface{}, error) { + return client.TCP(addr) + }); err != nil { + return nil, err } else { - // OK or some other temporary error - return conn, err + return c.(net.Conn), nil } } func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) { - rc.m.Lock() - defer rc.m.Unlock() - if rc.closed { - return nil, coreErrs.ClosedError{} - } - if rc.client == nil { - // No active connection, connect first - if err := rc.reconnect(); err != nil { - return nil, err - } - } - conn, err := rc.client.UDP() - if _, ok := err.(coreErrs.ClosedError); ok { - // Connection closed, reconnect - if err := rc.reconnect(); err != nil { - return nil, err - } - return rc.client.UDP() + if c, err := rc.clientDo(func(client Client) (interface{}, error) { + return client.UDP() + }); err != nil { + return nil, err } else { - // OK or some other temporary error - return conn, err + return c.(HyUDPConn), nil } }