fix: quic sniff not work if udp msg fragmentated

This commit is contained in:
Haruue 2024-09-22 22:48:06 +08:00
parent 21ea2a024a
commit 4ecbd57294
No known key found for this signature in database
GPG key ID: F6083B28CBCBC148

View file

@ -31,11 +31,62 @@ type udpEventLogger interface {
type udpSessionEntry struct {
ID uint32
Conn UDPConn
OverrideAddr string // Ignore the address in the UDP message, always use this if not empty
OriginalAddr string // The original address in the UDP message
D *frag.Defragger
Last *utils.AtomicTime
Timeout bool // true if the session is closed due to timeout
IO udpIO
DialFunc func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error)
ExitFunc func(err error)
timeoutChan chan struct{}
exitChan chan error
conn UDPConn
connLock sync.Mutex
closed bool
}
func newUDPSessionEntry(
id uint32, io udpIO,
dialFunc func(string, []byte) (UDPConn, string, error),
exitFunc func(error),
) (e *udpSessionEntry) {
e = &udpSessionEntry{
ID: id,
D: &frag.Defragger{},
Last: utils.NewAtomicTime(time.Now()),
IO: io,
DialFunc: dialFunc,
ExitFunc: exitFunc,
timeoutChan: make(chan struct{}),
exitChan: make(chan error, 2),
}
go func() {
// Guard routine
var err error
select {
case <-e.timeoutChan:
// Use nil error to indicate timeout.
case err = <-e.exitChan:
}
// We need this lock to ensure not to create conn after session exit
e.connLock.Lock()
e.closed = true
if e.conn != nil {
_ = e.conn.Close()
}
e.connLock.Unlock()
e.ExitFunc(err)
}()
return
}
// Feed feeds a UDP message to the session.
@ -49,27 +100,72 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
if dfMsg == nil {
return 0, nil
}
if e.OverrideAddr != "" {
return e.Conn.WriteTo(dfMsg.Data, e.OverrideAddr)
} else {
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
if e.conn == nil {
err := e.initConn(dfMsg)
if err != nil {
return 0, err
}
}
addr := dfMsg.Addr
if e.OverrideAddr != "" {
addr = e.OverrideAddr
}
return e.conn.WriteTo(dfMsg.Data, addr)
}
// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
// and sends using the provided io.
// Exit and returns error when either the underlying UDP connection returns
// error (e.g. closed), or the provided io returns error when sending.
func (e *udpSessionEntry) ReceiveLoop(io udpIO) error {
// initConn initializes the UDP connection of the session.
// If no error is returned, the e.conn is set to the new connection.
func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error {
// We need this lock to ensure not to create conn after session exit
e.connLock.Lock()
defer e.connLock.Unlock()
if e.closed {
return errors.New("session is closed")
}
conn, actualAddr, err := e.DialFunc(firstMsg.Addr, firstMsg.Data)
if err != nil {
// Fail fast if DailFunc failed
// (usually indicates the connection has been rejected by the ACL)
e.exitChan <- err
return err
}
e.conn = conn
if firstMsg.Addr != actualAddr {
e.OverrideAddr = actualAddr
e.OriginalAddr = firstMsg.Addr
}
go e.receiveLoop()
return nil
}
// receiveLoop receives incoming UDP packets, packs them into UDP messages,
// and sends using the IO.
// Exit when either the underlying UDP connection returns error (e.g. closed),
// or the IO returns error when sending.
func (e *udpSessionEntry) receiveLoop() {
udpBuf := make([]byte, protocol.MaxUDPSize)
msgBuf := make([]byte, protocol.MaxUDPSize)
for {
udpN, rAddr, err := e.Conn.ReadFrom(udpBuf)
udpN, rAddr, err := e.conn.ReadFrom(udpBuf)
if err != nil {
return err
e.exitChan <- err
return
}
e.Last.Set(time.Now())
if e.OriginalAddr != "" {
// Use the original address in the opposite direction,
// otherwise the QUIC clients or NAT on the client side
// may not treat it as the same UDP session.
rAddr = e.OriginalAddr
}
msg := &protocol.UDPMessage{
SessionID: e.ID,
PacketID: 0,
@ -78,13 +174,23 @@ func (e *udpSessionEntry) ReceiveLoop(io udpIO) error {
Addr: rAddr,
Data: udpBuf[:udpN],
}
err = sendMessageAutoFrag(io, msgBuf, msg)
err = sendMessageAutoFrag(e.IO, msgBuf, msg)
if err != nil {
return err
e.exitChan <- err
return
}
}
}
// MarkTimeout marks the session to be cleaned up due to timeout.
// Should only be called by the cleanup routine of the session manager.
func (e *udpSessionEntry) MarkTimeout() {
select {
case e.timeoutChan <- struct{}{}:
default:
}
}
// sendMessageAutoFrag tries to send a UDP message as a whole first,
// but if it fails due to quic.ErrMessageTooLarge, it tries again by
// fragmenting the message.
@ -168,10 +274,8 @@ func (m *udpSessionManager) cleanup(idleOnly bool) {
now := time.Now()
for _, entry := range m.m {
if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout {
entry.Timeout = true
_ = entry.Conn.Close()
// Closing the connection here will cause the ReceiveLoop to exit,
// and the session will be removed from the map there.
entry.MarkTimeout()
// Entry will be removed by its ExitFunc.
}
}
}
@ -183,47 +287,31 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
// Create a new session if not exists
if entry == nil {
// Call the hook
origMsgAddr := msg.Addr
err := m.io.Hook(msg.Data, &msg.Addr)
if err != nil {
return
}
// Log the event
m.eventLogger.New(msg.SessionID, msg.Addr)
// Dial target & create a new session entry
conn, err := m.io.UDP(msg.Addr)
if err != nil {
m.eventLogger.Close(msg.SessionID, err)
return
}
entry = &udpSessionEntry{
ID: msg.SessionID,
Conn: conn,
D: &frag.Defragger{},
Last: utils.NewAtomicTime(time.Now()),
}
if origMsgAddr != msg.Addr {
// Hook changed the address, enable address override
entry.OverrideAddr = msg.Addr
}
// Start the receive loop for this session
go func() {
err := entry.ReceiveLoop(m.io)
if !entry.Timeout {
_ = entry.Conn.Close()
m.eventLogger.Close(entry.ID, err)
} else {
// Connection already closed by timeout cleanup,
// no need to close again here.
// Use nil error to indicate timeout.
m.eventLogger.Close(entry.ID, nil)
dialFunc := func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) {
// Call the hook
err = m.io.Hook(firstMsgData, &addr)
if err != nil {
return
}
actualAddr = addr
// Log the event
m.eventLogger.New(msg.SessionID, addr)
// Dial target
conn, err = m.io.UDP(addr)
return
}
exitFunc := func(err error) {
// Log the event
m.eventLogger.Close(entry.ID, err)
// Remove the session from the map
m.mutex.Lock()
delete(m.m, entry.ID)
m.mutex.Unlock()
}()
}
entry = newUDPSessionEntry(msg.SessionID, m.io, dialFunc, exitFunc)
// Insert the session into the map
m.mutex.Lock()
m.m[msg.SessionID] = entry