From 766fcba38810b04499c028a20f993ab393b9fc84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 7 Apr 2024 19:50:17 +0800 Subject: [PATCH] Fix calculating UDP MTU --- hysteria/packet.go | 23 +++++------------------ hysteria2/packet.go | 23 +++++------------------ tuic/packet.go | 9 ++++++--- 3 files changed, 16 insertions(+), 39 deletions(-) diff --git a/hysteria/packet.go b/hysteria/packet.go index a2689c9..a597c76 100644 --- a/hysteria/packet.go +++ b/hysteria/packet.go @@ -125,7 +125,6 @@ type udpPacketConn struct { quicConn quic.Connection data chan *udpMessage udpMTU int - udpMTUTime time.Time packetId atomic.Uint32 closeOnce sync.Once defragger *udpDefragger @@ -140,6 +139,7 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy f cancel: cancel, quicConn: quicConn, data: make(chan *udpMessage, 64), + udpMTU: 1200, defragger: newUDPDefragger(), onDestroy: onDestroy, } @@ -174,15 +174,6 @@ func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } } -func (c *udpPacketConn) needFragment() bool { - nowTime := time.Now() - if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { - c.udpMTUTime = nowTime - return true - } - return false -} - func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() select { @@ -209,7 +200,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) } defer message.releaseMessage() var err error - if c.needFragment() && buffer.Len() > c.udpMTU { + if buffer.Len() > c.udpMTU-message.headerSize() { err = c.writePackets(fragUDPMessage(message, c.udpMTU)) } else { err = c.writePacket(message) @@ -221,9 +212,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) if !errors.As(err, &tooLargeErr) { return err } - c.udpMTU = int(tooLargeErr) - c.udpMTUTime = time.Now() - return c.writePackets(fragUDPMessage(message, c.udpMTU)) + return c.writePackets(fragUDPMessage(message, int(tooLargeErr))) } func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { @@ -250,7 +239,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { port: destination.Port, data: buf.As(p), } - if c.needFragment() && len(p) > c.udpMTU { + if len(p) > c.udpMTU-message.headerSize() { err = c.writePackets(fragUDPMessage(message, c.udpMTU)) if err == nil { return len(p), nil @@ -265,9 +254,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if !errors.As(err, &tooLargeErr) { return } - c.udpMTU = int(tooLargeErr) - c.udpMTUTime = time.Now() - err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + err = c.writePackets(fragUDPMessage(message, int(tooLargeErr))) if err == nil { return len(p), nil } diff --git a/hysteria2/packet.go b/hysteria2/packet.go index f577b15..e3c1f1c 100644 --- a/hysteria2/packet.go +++ b/hysteria2/packet.go @@ -121,7 +121,6 @@ type udpPacketConn struct { quicConn quic.Connection data chan *udpMessage udpMTU int - udpMTUTime time.Time packetId atomic.Uint32 closeOnce sync.Once defragger *udpDefragger @@ -136,6 +135,7 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy f cancel: cancel, quicConn: quicConn, data: make(chan *udpMessage, 64), + udpMTU: 1200, defragger: newUDPDefragger(), onDestroy: onDestroy, } @@ -170,15 +170,6 @@ func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } } -func (c *udpPacketConn) needFragment() bool { - nowTime := time.Now() - if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { - c.udpMTUTime = nowTime - return true - } - return false -} - func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() select { @@ -204,7 +195,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) } defer message.releaseMessage() var err error - if c.needFragment() && buffer.Len() > c.udpMTU { + if buffer.Len() > c.udpMTU-message.headerSize() { err = c.writePackets(fragUDPMessage(message, c.udpMTU)) } else { err = c.writePacket(message) @@ -216,9 +207,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) if !errors.As(err, &tooLargeErr) { return err } - c.udpMTU = int(tooLargeErr) - c.udpMTUTime = time.Now() - return c.writePackets(fragUDPMessage(message, c.udpMTU)) + return c.writePackets(fragUDPMessage(message, int(tooLargeErr))) } func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { @@ -243,7 +232,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { destination: addr.String(), data: buf.As(p), } - if c.needFragment() && len(p) > c.udpMTU { + if len(p) > c.udpMTU-message.headerSize() { err = c.writePackets(fragUDPMessage(message, c.udpMTU)) if err == nil { return len(p), nil @@ -258,9 +247,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if !errors.As(err, &tooLargeErr) { return } - c.udpMTU = int(tooLargeErr) - c.udpMTUTime = time.Now() - err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + err = c.writePackets(fragUDPMessage(message, int(tooLargeErr))) if err == nil { return len(p), nil } diff --git a/tuic/packet.go b/tuic/packet.go index 6eb07bf..888695f 100644 --- a/tuic/packet.go +++ b/tuic/packet.go @@ -180,8 +180,11 @@ func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (c *udpPacketConn) needFragment() bool { + if c.udpMTU == 0 { + return false + } nowTime := time.Now() - if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { + if nowTime.Sub(c.udpMTUTime) < 5*time.Second { c.udpMTUTime = nowTime return true } @@ -216,7 +219,7 @@ func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) } defer message.releaseMessage() var err error - if !c.udpStream && c.needFragment() && buffer.Len() > c.udpMTU { + if !c.udpStream && c.needFragment() && buffer.Len() > c.udpMTU-message.headerSize() { err = c.writePackets(fragUDPMessage(message, c.udpMTU)) } else { err = c.writePacket(message) @@ -259,7 +262,7 @@ func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { destination: destination, data: buf.As(p), } - if !c.udpStream && c.needFragment() && len(p) > c.udpMTU { + if !c.udpStream && c.needFragment() && len(p) > c.udpMTU-message.headerSize() { err = c.writePackets(fragUDPMessage(message, c.udpMTU)) if err == nil { return len(p), nil