Fix calculating UDP MTU

This commit is contained in:
世界 2024-04-07 19:50:17 +08:00
parent 6be1f3c03a
commit 766fcba388
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 16 additions and 39 deletions

View file

@ -125,7 +125,6 @@ type udpPacketConn struct {
quicConn quic.Connection quicConn quic.Connection
data chan *udpMessage data chan *udpMessage
udpMTU int udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32 packetId atomic.Uint32
closeOnce sync.Once closeOnce sync.Once
defragger *udpDefragger defragger *udpDefragger
@ -140,6 +139,7 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy f
cancel: cancel, cancel: cancel,
quicConn: quicConn, quicConn: quicConn,
data: make(chan *udpMessage, 64), data: make(chan *udpMessage, 64),
udpMTU: 1200,
defragger: newUDPDefragger(), defragger: newUDPDefragger(),
onDestroy: onDestroy, onDestroy: onDestroy,
} }
@ -174,15 +174,6 @@ func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
} }
} }
func (c *udpPacketConn) needFragment() bool {
nowTime := time.Now()
if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
c.udpMTUTime = nowTime
return true
}
return false
}
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release() defer buffer.Release()
select { select {
@ -209,7 +200,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
} }
defer message.releaseMessage() defer message.releaseMessage()
var err error var err error
if c.needFragment() && buffer.Len() > c.udpMTU { if buffer.Len() > c.udpMTU-message.headerSize() {
err = c.writePackets(fragUDPMessage(message, c.udpMTU)) err = c.writePackets(fragUDPMessage(message, c.udpMTU))
} else { } else {
err = c.writePacket(message) err = c.writePacket(message)
@ -221,9 +212,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
if !errors.As(err, &tooLargeErr) { if !errors.As(err, &tooLargeErr) {
return err return err
} }
c.udpMTU = int(tooLargeErr) return c.writePackets(fragUDPMessage(message, int(tooLargeErr)))
c.udpMTUTime = time.Now()
return c.writePackets(fragUDPMessage(message, c.udpMTU))
} }
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
@ -250,7 +239,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
port: destination.Port, port: destination.Port,
data: buf.As(p), data: buf.As(p),
} }
if c.needFragment() && len(p) > c.udpMTU { if len(p) > c.udpMTU-message.headerSize() {
err = c.writePackets(fragUDPMessage(message, c.udpMTU)) err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil { if err == nil {
return len(p), nil return len(p), nil
@ -265,9 +254,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !errors.As(err, &tooLargeErr) { if !errors.As(err, &tooLargeErr) {
return return
} }
c.udpMTU = int(tooLargeErr) err = c.writePackets(fragUDPMessage(message, int(tooLargeErr)))
c.udpMTUTime = time.Now()
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil { if err == nil {
return len(p), nil return len(p), nil
} }

View file

@ -121,7 +121,6 @@ type udpPacketConn struct {
quicConn quic.Connection quicConn quic.Connection
data chan *udpMessage data chan *udpMessage
udpMTU int udpMTU int
udpMTUTime time.Time
packetId atomic.Uint32 packetId atomic.Uint32
closeOnce sync.Once closeOnce sync.Once
defragger *udpDefragger defragger *udpDefragger
@ -136,6 +135,7 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy f
cancel: cancel, cancel: cancel,
quicConn: quicConn, quicConn: quicConn,
data: make(chan *udpMessage, 64), data: make(chan *udpMessage, 64),
udpMTU: 1200,
defragger: newUDPDefragger(), defragger: newUDPDefragger(),
onDestroy: onDestroy, onDestroy: onDestroy,
} }
@ -170,15 +170,6 @@ func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
} }
} }
func (c *udpPacketConn) needFragment() bool {
nowTime := time.Now()
if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
c.udpMTUTime = nowTime
return true
}
return false
}
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release() defer buffer.Release()
select { select {
@ -204,7 +195,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
} }
defer message.releaseMessage() defer message.releaseMessage()
var err error var err error
if c.needFragment() && buffer.Len() > c.udpMTU { if buffer.Len() > c.udpMTU-message.headerSize() {
err = c.writePackets(fragUDPMessage(message, c.udpMTU)) err = c.writePackets(fragUDPMessage(message, c.udpMTU))
} else { } else {
err = c.writePacket(message) err = c.writePacket(message)
@ -216,9 +207,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
if !errors.As(err, &tooLargeErr) { if !errors.As(err, &tooLargeErr) {
return err return err
} }
c.udpMTU = int(tooLargeErr) return c.writePackets(fragUDPMessage(message, int(tooLargeErr)))
c.udpMTUTime = time.Now()
return c.writePackets(fragUDPMessage(message, c.udpMTU))
} }
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
@ -243,7 +232,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination: addr.String(), destination: addr.String(),
data: buf.As(p), data: buf.As(p),
} }
if c.needFragment() && len(p) > c.udpMTU { if len(p) > c.udpMTU-message.headerSize() {
err = c.writePackets(fragUDPMessage(message, c.udpMTU)) err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil { if err == nil {
return len(p), nil return len(p), nil
@ -258,9 +247,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !errors.As(err, &tooLargeErr) { if !errors.As(err, &tooLargeErr) {
return return
} }
c.udpMTU = int(tooLargeErr) err = c.writePackets(fragUDPMessage(message, int(tooLargeErr)))
c.udpMTUTime = time.Now()
err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil { if err == nil {
return len(p), nil return len(p), nil
} }

View file

@ -180,8 +180,11 @@ func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
} }
func (c *udpPacketConn) needFragment() bool { func (c *udpPacketConn) needFragment() bool {
if c.udpMTU == 0 {
return false
}
nowTime := time.Now() nowTime := time.Now()
if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { if nowTime.Sub(c.udpMTUTime) < 5*time.Second {
c.udpMTUTime = nowTime c.udpMTUTime = nowTime
return true return true
} }
@ -216,7 +219,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr)
} }
defer message.releaseMessage() defer message.releaseMessage()
var err error var err error
if !c.udpStream && c.needFragment() && buffer.Len() > c.udpMTU { if !c.udpStream && c.needFragment() && buffer.Len() > c.udpMTU-message.headerSize() {
err = c.writePackets(fragUDPMessage(message, c.udpMTU)) err = c.writePackets(fragUDPMessage(message, c.udpMTU))
} else { } else {
err = c.writePacket(message) err = c.writePacket(message)
@ -259,7 +262,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination: destination, destination: destination,
data: buf.As(p), data: buf.As(p),
} }
if !c.udpStream && c.needFragment() && len(p) > c.udpMTU { if !c.udpStream && c.needFragment() && len(p) > c.udpMTU-message.headerSize() {
err = c.writePackets(fragUDPMessage(message, c.udpMTU)) err = c.writePackets(fragUDPMessage(message, c.udpMTU))
if err == nil { if err == nil {
return len(p), nil return len(p), nil