diff --git a/core/server/udp.go b/core/server/udp.go index 0ec0d5e..3d7cd71 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -40,9 +40,6 @@ type udpSessionEntry struct { DialFunc func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) ExitFunc func(err error) - timeoutChan chan struct{} - exitChan chan error - conn UDPConn connLock sync.Mutex closed bool @@ -61,34 +58,30 @@ func newUDPSessionEntry( DialFunc: dialFunc, ExitFunc: exitFunc, - - timeoutChan: make(chan struct{}), - exitChan: make(chan error, 2), } - go func() { - // Guard routine - var err error - select { - case <-e.timeoutChan: - // Use nil error to indicate timeout. - case err = <-e.exitChan: - } - - // We need this lock to ensure not to create conn after session exit - e.connLock.Lock() - e.closed = true - if e.conn != nil { - _ = e.conn.Close() - } - e.connLock.Unlock() - - e.ExitFunc(err) - }() - return } +func (e *udpSessionEntry) CloseWithErr(err error) { + // We need this lock to ensure not to create conn after session exit + e.connLock.Lock() + + if e.closed { + // Already closed + e.connLock.Unlock() + return + } + + e.closed = true + if e.conn != nil { + _ = e.conn.Close() + } + e.connLock.Unlock() + + e.ExitFunc(err) +} + // Feed feeds a UDP message to the session. // If the message itself is a complete message, or it completes a fragmented message, // the message is written to the session's UDP connection, and the number of bytes @@ -121,17 +114,18 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) { func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error { // We need this lock to ensure not to create conn after session exit e.connLock.Lock() - defer e.connLock.Unlock() if e.closed { + e.connLock.Unlock() return errors.New("session is closed") } conn, actualAddr, err := e.DialFunc(firstMsg.Addr, firstMsg.Data) if err != nil { - // Fail fast if DailFunc failed + // Fail fast if DialFunc failed // (usually indicates the connection has been rejected by the ACL) - e.exitChan <- err + e.connLock.Unlock() + e.CloseWithErr(err) return err } @@ -141,6 +135,8 @@ func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error { e.OriginalAddr = firstMsg.Addr } go e.receiveLoop() + + e.connLock.Unlock() return nil } @@ -154,7 +150,7 @@ func (e *udpSessionEntry) receiveLoop() { for { udpN, rAddr, err := e.conn.ReadFrom(udpBuf) if err != nil { - e.exitChan <- err + e.CloseWithErr(err) return } e.Last.Set(time.Now()) @@ -176,7 +172,7 @@ func (e *udpSessionEntry) receiveLoop() { } err = sendMessageAutoFrag(e.IO, msgBuf, msg) if err != nil { - e.exitChan <- err + e.CloseWithErr(err) return } } @@ -185,10 +181,8 @@ func (e *udpSessionEntry) receiveLoop() { // MarkTimeout marks the session to be cleaned up due to timeout. // Should only be called by the cleanup routine of the session manager. func (e *udpSessionEntry) MarkTimeout() { - select { - case e.timeoutChan <- struct{}{}: - default: - } + // nil error indicates timeout. + e.CloseWithErr(nil) } // sendMessageAutoFrag tries to send a UDP message as a whole first,