From 4ecbd572947b3dc41d089831612d7dd67140e91d Mon Sep 17 00:00:00 2001 From: Haruue Date: Sun, 22 Sep 2024 22:48:06 +0800 Subject: [PATCH 1/7] fix: quic sniff not work if udp msg fragmentated --- core/server/udp.go | 198 ++++++++++++++++++++++++++++++++------------- 1 file changed, 143 insertions(+), 55 deletions(-) diff --git a/core/server/udp.go b/core/server/udp.go index ecaee29..0ec0d5e 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -31,11 +31,62 @@ type udpEventLogger interface { type udpSessionEntry struct { ID uint32 - Conn UDPConn OverrideAddr string // Ignore the address in the UDP message, always use this if not empty + OriginalAddr string // The original address in the UDP message D *frag.Defragger Last *utils.AtomicTime - Timeout bool // true if the session is closed due to timeout + IO udpIO + + DialFunc func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) + ExitFunc func(err error) + + timeoutChan chan struct{} + exitChan chan error + + conn UDPConn + connLock sync.Mutex + closed bool +} + +func newUDPSessionEntry( + id uint32, io udpIO, + dialFunc func(string, []byte) (UDPConn, string, error), + exitFunc func(error), +) (e *udpSessionEntry) { + e = &udpSessionEntry{ + ID: id, + D: &frag.Defragger{}, + Last: utils.NewAtomicTime(time.Now()), + IO: io, + + DialFunc: dialFunc, + ExitFunc: exitFunc, + + timeoutChan: make(chan struct{}), + exitChan: make(chan error, 2), + } + + go func() { + // Guard routine + var err error + select { + case <-e.timeoutChan: + // Use nil error to indicate timeout. + case err = <-e.exitChan: + } + + // We need this lock to ensure not to create conn after session exit + e.connLock.Lock() + e.closed = true + if e.conn != nil { + _ = e.conn.Close() + } + e.connLock.Unlock() + + e.ExitFunc(err) + }() + + return } // Feed feeds a UDP message to the session. @@ -49,27 +100,72 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) { if dfMsg == nil { return 0, nil } - if e.OverrideAddr != "" { - return e.Conn.WriteTo(dfMsg.Data, e.OverrideAddr) - } else { - return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr) + + if e.conn == nil { + err := e.initConn(dfMsg) + if err != nil { + return 0, err + } } + + addr := dfMsg.Addr + if e.OverrideAddr != "" { + addr = e.OverrideAddr + } + + return e.conn.WriteTo(dfMsg.Data, addr) } -// ReceiveLoop receives incoming UDP packets, packs them into UDP messages, -// and sends using the provided io. -// Exit and returns error when either the underlying UDP connection returns -// error (e.g. closed), or the provided io returns error when sending. -func (e *udpSessionEntry) ReceiveLoop(io udpIO) error { +// initConn initializes the UDP connection of the session. +// If no error is returned, the e.conn is set to the new connection. +func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error { + // We need this lock to ensure not to create conn after session exit + e.connLock.Lock() + defer e.connLock.Unlock() + + if e.closed { + return errors.New("session is closed") + } + + conn, actualAddr, err := e.DialFunc(firstMsg.Addr, firstMsg.Data) + if err != nil { + // Fail fast if DailFunc failed + // (usually indicates the connection has been rejected by the ACL) + e.exitChan <- err + return err + } + + e.conn = conn + if firstMsg.Addr != actualAddr { + e.OverrideAddr = actualAddr + e.OriginalAddr = firstMsg.Addr + } + go e.receiveLoop() + return nil +} + +// receiveLoop receives incoming UDP packets, packs them into UDP messages, +// and sends using the IO. +// Exit when either the underlying UDP connection returns error (e.g. closed), +// or the IO returns error when sending. +func (e *udpSessionEntry) receiveLoop() { udpBuf := make([]byte, protocol.MaxUDPSize) msgBuf := make([]byte, protocol.MaxUDPSize) for { - udpN, rAddr, err := e.Conn.ReadFrom(udpBuf) + udpN, rAddr, err := e.conn.ReadFrom(udpBuf) if err != nil { - return err + e.exitChan <- err + return } e.Last.Set(time.Now()) + if e.OriginalAddr != "" { + // Use the original address in the opposite direction, + // otherwise the QUIC clients or NAT on the client side + // may not treat it as the same UDP session. + rAddr = e.OriginalAddr + } + msg := &protocol.UDPMessage{ SessionID: e.ID, PacketID: 0, @@ -78,13 +174,23 @@ func (e *udpSessionEntry) ReceiveLoop(io udpIO) error { Addr: rAddr, Data: udpBuf[:udpN], } - err = sendMessageAutoFrag(io, msgBuf, msg) + err = sendMessageAutoFrag(e.IO, msgBuf, msg) if err != nil { - return err + e.exitChan <- err + return } } } +// MarkTimeout marks the session to be cleaned up due to timeout. +// Should only be called by the cleanup routine of the session manager. +func (e *udpSessionEntry) MarkTimeout() { + select { + case e.timeoutChan <- struct{}{}: + default: + } +} + // sendMessageAutoFrag tries to send a UDP message as a whole first, // but if it fails due to quic.ErrMessageTooLarge, it tries again by // fragmenting the message. @@ -168,10 +274,8 @@ func (m *udpSessionManager) cleanup(idleOnly bool) { now := time.Now() for _, entry := range m.m { if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout { - entry.Timeout = true - _ = entry.Conn.Close() - // Closing the connection here will cause the ReceiveLoop to exit, - // and the session will be removed from the map there. + entry.MarkTimeout() + // Entry will be removed by its ExitFunc. } } } @@ -183,47 +287,31 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { // Create a new session if not exists if entry == nil { - // Call the hook - origMsgAddr := msg.Addr - err := m.io.Hook(msg.Data, &msg.Addr) - if err != nil { - return - } - // Log the event - m.eventLogger.New(msg.SessionID, msg.Addr) - // Dial target & create a new session entry - conn, err := m.io.UDP(msg.Addr) - if err != nil { - m.eventLogger.Close(msg.SessionID, err) - return - } - entry = &udpSessionEntry{ - ID: msg.SessionID, - Conn: conn, - D: &frag.Defragger{}, - Last: utils.NewAtomicTime(time.Now()), - } - if origMsgAddr != msg.Addr { - // Hook changed the address, enable address override - entry.OverrideAddr = msg.Addr - } - // Start the receive loop for this session - go func() { - err := entry.ReceiveLoop(m.io) - if !entry.Timeout { - _ = entry.Conn.Close() - m.eventLogger.Close(entry.ID, err) - } else { - // Connection already closed by timeout cleanup, - // no need to close again here. - // Use nil error to indicate timeout. - m.eventLogger.Close(entry.ID, nil) + dialFunc := func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) { + // Call the hook + err = m.io.Hook(firstMsgData, &addr) + if err != nil { + return } + actualAddr = addr + // Log the event + m.eventLogger.New(msg.SessionID, addr) + // Dial target + conn, err = m.io.UDP(addr) + return + } + exitFunc := func(err error) { + // Log the event + m.eventLogger.Close(entry.ID, err) + // Remove the session from the map m.mutex.Lock() delete(m.m, entry.ID) m.mutex.Unlock() - }() + } + + entry = newUDPSessionEntry(msg.SessionID, m.io, dialFunc, exitFunc) + // Insert the session into the map m.mutex.Lock() m.m[msg.SessionID] = entry From 931fc2fdb2bbbba0e78f131e020898cc597abb0b Mon Sep 17 00:00:00 2001 From: Haruue Date: Fri, 4 Oct 2024 11:27:36 +0800 Subject: [PATCH 2/7] chore: replace guard routine with CloseWithErr() --- core/server/udp.go | 64 +++++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 35 deletions(-) diff --git a/core/server/udp.go b/core/server/udp.go index 0ec0d5e..3d7cd71 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -40,9 +40,6 @@ type udpSessionEntry struct { DialFunc func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) ExitFunc func(err error) - timeoutChan chan struct{} - exitChan chan error - conn UDPConn connLock sync.Mutex closed bool @@ -61,34 +58,30 @@ func newUDPSessionEntry( DialFunc: dialFunc, ExitFunc: exitFunc, - - timeoutChan: make(chan struct{}), - exitChan: make(chan error, 2), } - go func() { - // Guard routine - var err error - select { - case <-e.timeoutChan: - // Use nil error to indicate timeout. - case err = <-e.exitChan: - } - - // We need this lock to ensure not to create conn after session exit - e.connLock.Lock() - e.closed = true - if e.conn != nil { - _ = e.conn.Close() - } - e.connLock.Unlock() - - e.ExitFunc(err) - }() - return } +func (e *udpSessionEntry) CloseWithErr(err error) { + // We need this lock to ensure not to create conn after session exit + e.connLock.Lock() + + if e.closed { + // Already closed + e.connLock.Unlock() + return + } + + e.closed = true + if e.conn != nil { + _ = e.conn.Close() + } + e.connLock.Unlock() + + e.ExitFunc(err) +} + // Feed feeds a UDP message to the session. // If the message itself is a complete message, or it completes a fragmented message, // the message is written to the session's UDP connection, and the number of bytes @@ -121,17 +114,18 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) { func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error { // We need this lock to ensure not to create conn after session exit e.connLock.Lock() - defer e.connLock.Unlock() if e.closed { + e.connLock.Unlock() return errors.New("session is closed") } conn, actualAddr, err := e.DialFunc(firstMsg.Addr, firstMsg.Data) if err != nil { - // Fail fast if DailFunc failed + // Fail fast if DialFunc failed // (usually indicates the connection has been rejected by the ACL) - e.exitChan <- err + e.connLock.Unlock() + e.CloseWithErr(err) return err } @@ -141,6 +135,8 @@ func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error { e.OriginalAddr = firstMsg.Addr } go e.receiveLoop() + + e.connLock.Unlock() return nil } @@ -154,7 +150,7 @@ func (e *udpSessionEntry) receiveLoop() { for { udpN, rAddr, err := e.conn.ReadFrom(udpBuf) if err != nil { - e.exitChan <- err + e.CloseWithErr(err) return } e.Last.Set(time.Now()) @@ -176,7 +172,7 @@ func (e *udpSessionEntry) receiveLoop() { } err = sendMessageAutoFrag(e.IO, msgBuf, msg) if err != nil { - e.exitChan <- err + e.CloseWithErr(err) return } } @@ -185,10 +181,8 @@ func (e *udpSessionEntry) receiveLoop() { // MarkTimeout marks the session to be cleaned up due to timeout. // Should only be called by the cleanup routine of the session manager. func (e *udpSessionEntry) MarkTimeout() { - select { - case e.timeoutChan <- struct{}{}: - default: - } + // nil error indicates timeout. + e.CloseWithErr(nil) } // sendMessageAutoFrag tries to send a UDP message as a whole first, From dc023ae13a2ed16c260ab566f9086cc56e8685bd Mon Sep 17 00:00:00 2001 From: Haruue Date: Fri, 4 Oct 2024 16:33:41 +0800 Subject: [PATCH 3/7] fix: udpSessionManager.mutex reentrant by cleanup --- core/server/udp.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/server/udp.go b/core/server/udp.go index 3d7cd71..ec470fc 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -262,16 +262,21 @@ func (m *udpSessionManager) idleCleanupLoop(stopCh <-chan struct{}) { func (m *udpSessionManager) cleanup(idleOnly bool) { // We use RLock here as we are only scanning the map, not deleting from it. - m.mutex.RLock() - defer m.mutex.RUnlock() + timeoutEntry := make([]*udpSessionEntry, 0, len(m.m)) + m.mutex.RLock() now := time.Now() for _, entry := range m.m { if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout { - entry.MarkTimeout() - // Entry will be removed by its ExitFunc. + timeoutEntry = append(timeoutEntry, entry) } } + m.mutex.RUnlock() + + for _, entry := range timeoutEntry { + entry.MarkTimeout() + // Entry will be removed by its ExitFunc. + } } func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { From 4e2f138008c7fa91f4c84624f074207389ddd11f Mon Sep 17 00:00:00 2001 From: Haruue Date: Fri, 4 Oct 2024 16:40:15 +0800 Subject: [PATCH 4/7] chore: fix comments --- core/server/udp.go | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/core/server/udp.go b/core/server/udp.go index ec470fc..eb10b19 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -63,6 +63,8 @@ func newUDPSessionEntry( return } +// CloseWithErr closes the session and calls ExitFunc with the given error. +// A nil error indicates the session is cleaned up due to timeout. func (e *udpSessionEntry) CloseWithErr(err error) { // We need this lock to ensure not to create conn after session exit e.connLock.Lock() @@ -125,6 +127,7 @@ func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error { // Fail fast if DialFunc failed // (usually indicates the connection has been rejected by the ACL) e.connLock.Unlock() + // CloseWithErr acquires the connLock again e.CloseWithErr(err) return err } @@ -178,13 +181,6 @@ func (e *udpSessionEntry) receiveLoop() { } } -// MarkTimeout marks the session to be cleaned up due to timeout. -// Should only be called by the cleanup routine of the session manager. -func (e *udpSessionEntry) MarkTimeout() { - // nil error indicates timeout. - e.CloseWithErr(nil) -} - // sendMessageAutoFrag tries to send a UDP message as a whole first, // but if it fails due to quic.ErrMessageTooLarge, it tries again by // fragmenting the message. @@ -261,9 +257,9 @@ func (m *udpSessionManager) idleCleanupLoop(stopCh <-chan struct{}) { } func (m *udpSessionManager) cleanup(idleOnly bool) { - // We use RLock here as we are only scanning the map, not deleting from it. timeoutEntry := make([]*udpSessionEntry, 0, len(m.m)) + // We use RLock here as we are only scanning the map, not deleting from it. m.mutex.RLock() now := time.Now() for _, entry := range m.m { @@ -274,8 +270,9 @@ func (m *udpSessionManager) cleanup(idleOnly bool) { m.mutex.RUnlock() for _, entry := range timeoutEntry { - entry.MarkTimeout() - // Entry will be removed by its ExitFunc. + // This eventually calls entry.ExitFunc, + // where the m.mutex will be locked again to remove the entry from the map. + entry.CloseWithErr(nil) } } From 947701897b0562230f133b6cc81674f6380bb898 Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 4 Oct 2024 10:29:25 -0700 Subject: [PATCH 5/7] fix: TestClientServerHookUDP --- core/internal/integration_tests/hook_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/internal/integration_tests/hook_test.go b/core/internal/integration_tests/hook_test.go index 1121d13..64affe8 100644 --- a/core/internal/integration_tests/hook_test.go +++ b/core/internal/integration_tests/hook_test.go @@ -132,7 +132,9 @@ func TestClientServerHookUDP(t *testing.T) { rData, rAddr, err := conn.Receive() assert.NoError(t, err) assert.Equal(t, sData, rData) - assert.Equal(t, realEchoAddr, rAddr) + // Hook address change is transparent, + // the client should still see the fake echo address it sent packets to + assert.Equal(t, fakeEchoAddr, rAddr) // Subsequent packets should also be sent to the real echo server sData = []byte("never stop fighting") @@ -141,5 +143,5 @@ func TestClientServerHookUDP(t *testing.T) { rData, rAddr, err = conn.Receive() assert.NoError(t, err) assert.Equal(t, sData, rData) - assert.Equal(t, realEchoAddr, rAddr) + assert.Equal(t, fakeEchoAddr, rAddr) } From b3116c62682871a3a15234c1493a79aad805b196 Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 4 Oct 2024 10:47:41 -0700 Subject: [PATCH 6/7] feat: update TestUDPSessionManager to cover the fragmented msg hook --- core/server/udp_test.go | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/core/server/udp_test.go b/core/server/udp_test.go index 045edbd..8aa899f 100644 --- a/core/server/udp_test.go +++ b/core/server/udp_test.go @@ -25,7 +25,6 @@ func TestUDPSessionManager(t *testing.T) { } return m, nil }) - io.EXPECT().Hook(mock.Anything, mock.Anything).Return(nil) go sm.Run() @@ -50,6 +49,7 @@ func TestUDPSessionManager(t *testing.T) { eventLogger.EXPECT().New(msg1.SessionID, msg1.Addr).Return().Once() udpConn1 := newMockUDPConn(t) udpConn1Ch := make(chan []byte, 1) + io.EXPECT().Hook(msg1.Data, &msg1.Addr).Return(nil).Once() io.EXPECT().UDP(msg1.Addr).Return(udpConn1, nil).Once() udpConn1.EXPECT().WriteTo(msg1.Data, msg1.Addr).Return(5, nil).Once() udpConn1.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, string, error) { @@ -66,31 +66,44 @@ func TestUDPSessionManager(t *testing.T) { msgCh <- msg1 udpConn1Ch <- []byte("hi back") - msg2 := &protocol.UDPMessage{ + msg2data := []byte("how are you doing?") + msg2_1 := &protocol.UDPMessage{ SessionID: 5678, PacketID: 0, FragID: 0, - FragCount: 1, + FragCount: 2, Addr: "address2.net:12450", - Data: []byte("how are you"), + Data: msg2data[:6], } - eventLogger.EXPECT().New(msg2.SessionID, msg2.Addr).Return().Once() + msg2_2 := &protocol.UDPMessage{ + SessionID: 5678, + PacketID: 0, + FragID: 1, + FragCount: 2, + Addr: "address2.net:12450", + Data: msg2data[6:], + } + + eventLogger.EXPECT().New(msg2_1.SessionID, msg2_1.Addr).Return().Once() udpConn2 := newMockUDPConn(t) udpConn2Ch := make(chan []byte, 1) - io.EXPECT().UDP(msg2.Addr).Return(udpConn2, nil).Once() - udpConn2.EXPECT().WriteTo(msg2.Data, msg2.Addr).Return(11, nil).Once() + // On fragmentation, make sure hook gets the whole message + io.EXPECT().Hook(msg2data, &msg2_1.Addr).Return(nil).Once() + io.EXPECT().UDP(msg2_1.Addr).Return(udpConn2, nil).Once() + udpConn2.EXPECT().WriteTo(msg2data, msg2_1.Addr).Return(11, nil).Once() udpConn2.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, string, error) { - return udpReadFunc(msg2.Addr, udpConn2Ch, b) + return udpReadFunc(msg2_1.Addr, udpConn2Ch, b) }) io.EXPECT().SendMessage(mock.Anything, &protocol.UDPMessage{ - SessionID: msg2.SessionID, + SessionID: msg2_1.SessionID, PacketID: 0, FragID: 0, FragCount: 1, - Addr: msg2.Addr, + Addr: msg2_1.Addr, Data: []byte("im fine"), }).Return(nil).Once() - msgCh <- msg2 + msgCh <- msg2_1 + msgCh <- msg2_2 udpConn2Ch <- []byte("im fine") msg3 := &protocol.UDPMessage{ @@ -123,7 +136,7 @@ func TestUDPSessionManager(t *testing.T) { return nil }).Once() eventLogger.EXPECT().Close(msg1.SessionID, nil).Once() - eventLogger.EXPECT().Close(msg2.SessionID, nil).Once() + eventLogger.EXPECT().Close(msg2_1.SessionID, nil).Once() time.Sleep(3 * time.Second) // Wait for timeout mock.AssertExpectationsForObjects(t, io, eventLogger, udpConn1, udpConn2) @@ -140,6 +153,7 @@ func TestUDPSessionManager(t *testing.T) { } eventLogger.EXPECT().New(msg4.SessionID, msg4.Addr).Return().Once() udpConn4 := newMockUDPConn(t) + io.EXPECT().Hook(msg4.Data, &msg4.Addr).Return(nil).Once() io.EXPECT().UDP(msg4.Addr).Return(udpConn4, nil).Once() udpConn4.EXPECT().WriteTo(msg4.Data, msg4.Addr).Return(12, nil).Once() udpConn4.EXPECT().ReadFrom(mock.Anything).Return(0, "", errUDPClosed).Once() @@ -161,6 +175,7 @@ func TestUDPSessionManager(t *testing.T) { Data: []byte("babe i miss you"), } eventLogger.EXPECT().New(msg5.SessionID, msg5.Addr).Return().Once() + io.EXPECT().Hook(msg5.Data, &msg5.Addr).Return(nil).Once() io.EXPECT().UDP(msg5.Addr).Return(nil, errUDPIO).Once() eventLogger.EXPECT().Close(msg5.SessionID, errUDPIO).Once() msgCh <- msg5 From 1001b2b1adb5eb21e7f3659e58b353aed42239d3 Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 5 Oct 2024 10:23:43 +0800 Subject: [PATCH 7/7] chore: fix comments --- core/server/udp.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/server/udp.go b/core/server/udp.go index eb10b19..14efc9e 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -133,7 +133,9 @@ func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error { } e.conn = conn + if firstMsg.Addr != actualAddr { + // Hook changed the address, enable address override e.OverrideAddr = actualAddr e.OriginalAddr = firstMsg.Addr }