mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-02 03:57:38 +03:00
Merge pull request #1206 from apernet/fix-quic-sniff
fix: quic sniff not work if udp msg fragmentated
This commit is contained in:
commit
ecc95fb973
3 changed files with 174 additions and 71 deletions
|
@ -132,7 +132,9 @@ func TestClientServerHookUDP(t *testing.T) {
|
|||
rData, rAddr, err := conn.Receive()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sData, rData)
|
||||
assert.Equal(t, realEchoAddr, rAddr)
|
||||
// Hook address change is transparent,
|
||||
// the client should still see the fake echo address it sent packets to
|
||||
assert.Equal(t, fakeEchoAddr, rAddr)
|
||||
|
||||
// Subsequent packets should also be sent to the real echo server
|
||||
sData = []byte("never stop fighting")
|
||||
|
@ -141,5 +143,5 @@ func TestClientServerHookUDP(t *testing.T) {
|
|||
rData, rAddr, err = conn.Receive()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sData, rData)
|
||||
assert.Equal(t, realEchoAddr, rAddr)
|
||||
assert.Equal(t, fakeEchoAddr, rAddr)
|
||||
}
|
||||
|
|
|
@ -31,11 +31,57 @@ type udpEventLogger interface {
|
|||
|
||||
type udpSessionEntry struct {
|
||||
ID uint32
|
||||
Conn UDPConn
|
||||
OverrideAddr string // Ignore the address in the UDP message, always use this if not empty
|
||||
OriginalAddr string // The original address in the UDP message
|
||||
D *frag.Defragger
|
||||
Last *utils.AtomicTime
|
||||
Timeout bool // true if the session is closed due to timeout
|
||||
IO udpIO
|
||||
|
||||
DialFunc func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error)
|
||||
ExitFunc func(err error)
|
||||
|
||||
conn UDPConn
|
||||
connLock sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newUDPSessionEntry(
|
||||
id uint32, io udpIO,
|
||||
dialFunc func(string, []byte) (UDPConn, string, error),
|
||||
exitFunc func(error),
|
||||
) (e *udpSessionEntry) {
|
||||
e = &udpSessionEntry{
|
||||
ID: id,
|
||||
D: &frag.Defragger{},
|
||||
Last: utils.NewAtomicTime(time.Now()),
|
||||
IO: io,
|
||||
|
||||
DialFunc: dialFunc,
|
||||
ExitFunc: exitFunc,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// CloseWithErr closes the session and calls ExitFunc with the given error.
|
||||
// A nil error indicates the session is cleaned up due to timeout.
|
||||
func (e *udpSessionEntry) CloseWithErr(err error) {
|
||||
// We need this lock to ensure not to create conn after session exit
|
||||
e.connLock.Lock()
|
||||
|
||||
if e.closed {
|
||||
// Already closed
|
||||
e.connLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
e.closed = true
|
||||
if e.conn != nil {
|
||||
_ = e.conn.Close()
|
||||
}
|
||||
e.connLock.Unlock()
|
||||
|
||||
e.ExitFunc(err)
|
||||
}
|
||||
|
||||
// Feed feeds a UDP message to the session.
|
||||
|
@ -49,27 +95,78 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
|
|||
if dfMsg == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if e.OverrideAddr != "" {
|
||||
return e.Conn.WriteTo(dfMsg.Data, e.OverrideAddr)
|
||||
} else {
|
||||
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
||||
|
||||
if e.conn == nil {
|
||||
err := e.initConn(dfMsg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
addr := dfMsg.Addr
|
||||
if e.OverrideAddr != "" {
|
||||
addr = e.OverrideAddr
|
||||
}
|
||||
|
||||
return e.conn.WriteTo(dfMsg.Data, addr)
|
||||
}
|
||||
|
||||
// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
|
||||
// and sends using the provided io.
|
||||
// Exit and returns error when either the underlying UDP connection returns
|
||||
// error (e.g. closed), or the provided io returns error when sending.
|
||||
func (e *udpSessionEntry) ReceiveLoop(io udpIO) error {
|
||||
// initConn initializes the UDP connection of the session.
|
||||
// If no error is returned, the e.conn is set to the new connection.
|
||||
func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error {
|
||||
// We need this lock to ensure not to create conn after session exit
|
||||
e.connLock.Lock()
|
||||
|
||||
if e.closed {
|
||||
e.connLock.Unlock()
|
||||
return errors.New("session is closed")
|
||||
}
|
||||
|
||||
conn, actualAddr, err := e.DialFunc(firstMsg.Addr, firstMsg.Data)
|
||||
if err != nil {
|
||||
// Fail fast if DialFunc failed
|
||||
// (usually indicates the connection has been rejected by the ACL)
|
||||
e.connLock.Unlock()
|
||||
// CloseWithErr acquires the connLock again
|
||||
e.CloseWithErr(err)
|
||||
return err
|
||||
}
|
||||
|
||||
e.conn = conn
|
||||
|
||||
if firstMsg.Addr != actualAddr {
|
||||
// Hook changed the address, enable address override
|
||||
e.OverrideAddr = actualAddr
|
||||
e.OriginalAddr = firstMsg.Addr
|
||||
}
|
||||
go e.receiveLoop()
|
||||
|
||||
e.connLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// receiveLoop receives incoming UDP packets, packs them into UDP messages,
|
||||
// and sends using the IO.
|
||||
// Exit when either the underlying UDP connection returns error (e.g. closed),
|
||||
// or the IO returns error when sending.
|
||||
func (e *udpSessionEntry) receiveLoop() {
|
||||
udpBuf := make([]byte, protocol.MaxUDPSize)
|
||||
msgBuf := make([]byte, protocol.MaxUDPSize)
|
||||
for {
|
||||
udpN, rAddr, err := e.Conn.ReadFrom(udpBuf)
|
||||
udpN, rAddr, err := e.conn.ReadFrom(udpBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
e.CloseWithErr(err)
|
||||
return
|
||||
}
|
||||
e.Last.Set(time.Now())
|
||||
|
||||
if e.OriginalAddr != "" {
|
||||
// Use the original address in the opposite direction,
|
||||
// otherwise the QUIC clients or NAT on the client side
|
||||
// may not treat it as the same UDP session.
|
||||
rAddr = e.OriginalAddr
|
||||
}
|
||||
|
||||
msg := &protocol.UDPMessage{
|
||||
SessionID: e.ID,
|
||||
PacketID: 0,
|
||||
|
@ -78,9 +175,10 @@ func (e *udpSessionEntry) ReceiveLoop(io udpIO) error {
|
|||
Addr: rAddr,
|
||||
Data: udpBuf[:udpN],
|
||||
}
|
||||
err = sendMessageAutoFrag(io, msgBuf, msg)
|
||||
err = sendMessageAutoFrag(e.IO, msgBuf, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
e.CloseWithErr(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -161,19 +259,23 @@ func (m *udpSessionManager) idleCleanupLoop(stopCh <-chan struct{}) {
|
|||
}
|
||||
|
||||
func (m *udpSessionManager) cleanup(idleOnly bool) {
|
||||
timeoutEntry := make([]*udpSessionEntry, 0, len(m.m))
|
||||
|
||||
// We use RLock here as we are only scanning the map, not deleting from it.
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
for _, entry := range m.m {
|
||||
if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout {
|
||||
entry.Timeout = true
|
||||
_ = entry.Conn.Close()
|
||||
// Closing the connection here will cause the ReceiveLoop to exit,
|
||||
// and the session will be removed from the map there.
|
||||
timeoutEntry = append(timeoutEntry, entry)
|
||||
}
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
|
||||
for _, entry := range timeoutEntry {
|
||||
// This eventually calls entry.ExitFunc,
|
||||
// where the m.mutex will be locked again to remove the entry from the map.
|
||||
entry.CloseWithErr(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
||||
|
@ -183,47 +285,31 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
|||
|
||||
// Create a new session if not exists
|
||||
if entry == nil {
|
||||
// Call the hook
|
||||
origMsgAddr := msg.Addr
|
||||
err := m.io.Hook(msg.Data, &msg.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Log the event
|
||||
m.eventLogger.New(msg.SessionID, msg.Addr)
|
||||
// Dial target & create a new session entry
|
||||
conn, err := m.io.UDP(msg.Addr)
|
||||
if err != nil {
|
||||
m.eventLogger.Close(msg.SessionID, err)
|
||||
return
|
||||
}
|
||||
entry = &udpSessionEntry{
|
||||
ID: msg.SessionID,
|
||||
Conn: conn,
|
||||
D: &frag.Defragger{},
|
||||
Last: utils.NewAtomicTime(time.Now()),
|
||||
}
|
||||
if origMsgAddr != msg.Addr {
|
||||
// Hook changed the address, enable address override
|
||||
entry.OverrideAddr = msg.Addr
|
||||
}
|
||||
// Start the receive loop for this session
|
||||
go func() {
|
||||
err := entry.ReceiveLoop(m.io)
|
||||
if !entry.Timeout {
|
||||
_ = entry.Conn.Close()
|
||||
m.eventLogger.Close(entry.ID, err)
|
||||
} else {
|
||||
// Connection already closed by timeout cleanup,
|
||||
// no need to close again here.
|
||||
// Use nil error to indicate timeout.
|
||||
m.eventLogger.Close(entry.ID, nil)
|
||||
dialFunc := func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) {
|
||||
// Call the hook
|
||||
err = m.io.Hook(firstMsgData, &addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
actualAddr = addr
|
||||
// Log the event
|
||||
m.eventLogger.New(msg.SessionID, addr)
|
||||
// Dial target
|
||||
conn, err = m.io.UDP(addr)
|
||||
return
|
||||
}
|
||||
exitFunc := func(err error) {
|
||||
// Log the event
|
||||
m.eventLogger.Close(entry.ID, err)
|
||||
|
||||
// Remove the session from the map
|
||||
m.mutex.Lock()
|
||||
delete(m.m, entry.ID)
|
||||
m.mutex.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
entry = newUDPSessionEntry(msg.SessionID, m.io, dialFunc, exitFunc)
|
||||
|
||||
// Insert the session into the map
|
||||
m.mutex.Lock()
|
||||
m.m[msg.SessionID] = entry
|
||||
|
|
|
@ -25,7 +25,6 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
}
|
||||
return m, nil
|
||||
})
|
||||
io.EXPECT().Hook(mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
go sm.Run()
|
||||
|
||||
|
@ -50,6 +49,7 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
eventLogger.EXPECT().New(msg1.SessionID, msg1.Addr).Return().Once()
|
||||
udpConn1 := newMockUDPConn(t)
|
||||
udpConn1Ch := make(chan []byte, 1)
|
||||
io.EXPECT().Hook(msg1.Data, &msg1.Addr).Return(nil).Once()
|
||||
io.EXPECT().UDP(msg1.Addr).Return(udpConn1, nil).Once()
|
||||
udpConn1.EXPECT().WriteTo(msg1.Data, msg1.Addr).Return(5, nil).Once()
|
||||
udpConn1.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, string, error) {
|
||||
|
@ -66,31 +66,44 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
msgCh <- msg1
|
||||
udpConn1Ch <- []byte("hi back")
|
||||
|
||||
msg2 := &protocol.UDPMessage{
|
||||
msg2data := []byte("how are you doing?")
|
||||
msg2_1 := &protocol.UDPMessage{
|
||||
SessionID: 5678,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
FragCount: 2,
|
||||
Addr: "address2.net:12450",
|
||||
Data: []byte("how are you"),
|
||||
Data: msg2data[:6],
|
||||
}
|
||||
eventLogger.EXPECT().New(msg2.SessionID, msg2.Addr).Return().Once()
|
||||
msg2_2 := &protocol.UDPMessage{
|
||||
SessionID: 5678,
|
||||
PacketID: 0,
|
||||
FragID: 1,
|
||||
FragCount: 2,
|
||||
Addr: "address2.net:12450",
|
||||
Data: msg2data[6:],
|
||||
}
|
||||
|
||||
eventLogger.EXPECT().New(msg2_1.SessionID, msg2_1.Addr).Return().Once()
|
||||
udpConn2 := newMockUDPConn(t)
|
||||
udpConn2Ch := make(chan []byte, 1)
|
||||
io.EXPECT().UDP(msg2.Addr).Return(udpConn2, nil).Once()
|
||||
udpConn2.EXPECT().WriteTo(msg2.Data, msg2.Addr).Return(11, nil).Once()
|
||||
// On fragmentation, make sure hook gets the whole message
|
||||
io.EXPECT().Hook(msg2data, &msg2_1.Addr).Return(nil).Once()
|
||||
io.EXPECT().UDP(msg2_1.Addr).Return(udpConn2, nil).Once()
|
||||
udpConn2.EXPECT().WriteTo(msg2data, msg2_1.Addr).Return(11, nil).Once()
|
||||
udpConn2.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, string, error) {
|
||||
return udpReadFunc(msg2.Addr, udpConn2Ch, b)
|
||||
return udpReadFunc(msg2_1.Addr, udpConn2Ch, b)
|
||||
})
|
||||
io.EXPECT().SendMessage(mock.Anything, &protocol.UDPMessage{
|
||||
SessionID: msg2.SessionID,
|
||||
SessionID: msg2_1.SessionID,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: msg2.Addr,
|
||||
Addr: msg2_1.Addr,
|
||||
Data: []byte("im fine"),
|
||||
}).Return(nil).Once()
|
||||
msgCh <- msg2
|
||||
msgCh <- msg2_1
|
||||
msgCh <- msg2_2
|
||||
udpConn2Ch <- []byte("im fine")
|
||||
|
||||
msg3 := &protocol.UDPMessage{
|
||||
|
@ -123,7 +136,7 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
return nil
|
||||
}).Once()
|
||||
eventLogger.EXPECT().Close(msg1.SessionID, nil).Once()
|
||||
eventLogger.EXPECT().Close(msg2.SessionID, nil).Once()
|
||||
eventLogger.EXPECT().Close(msg2_1.SessionID, nil).Once()
|
||||
|
||||
time.Sleep(3 * time.Second) // Wait for timeout
|
||||
mock.AssertExpectationsForObjects(t, io, eventLogger, udpConn1, udpConn2)
|
||||
|
@ -140,6 +153,7 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
}
|
||||
eventLogger.EXPECT().New(msg4.SessionID, msg4.Addr).Return().Once()
|
||||
udpConn4 := newMockUDPConn(t)
|
||||
io.EXPECT().Hook(msg4.Data, &msg4.Addr).Return(nil).Once()
|
||||
io.EXPECT().UDP(msg4.Addr).Return(udpConn4, nil).Once()
|
||||
udpConn4.EXPECT().WriteTo(msg4.Data, msg4.Addr).Return(12, nil).Once()
|
||||
udpConn4.EXPECT().ReadFrom(mock.Anything).Return(0, "", errUDPClosed).Once()
|
||||
|
@ -161,6 +175,7 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
Data: []byte("babe i miss you"),
|
||||
}
|
||||
eventLogger.EXPECT().New(msg5.SessionID, msg5.Addr).Return().Once()
|
||||
io.EXPECT().Hook(msg5.Data, &msg5.Addr).Return(nil).Once()
|
||||
io.EXPECT().UDP(msg5.Addr).Return(nil, errUDPIO).Once()
|
||||
eventLogger.EXPECT().Close(msg5.SessionID, errUDPIO).Once()
|
||||
msgCh <- msg5
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue