mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 03:47:38 +03:00
Fix calculate mtu
This commit is contained in:
parent
cb9b17d6a4
commit
3483762200
3 changed files with 50 additions and 32 deletions
|
@ -4,8 +4,6 @@ import (
|
|||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
@ -23,17 +21,34 @@ func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
|
|||
return Copy(w, r.Reader)
|
||||
}
|
||||
|
||||
func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool {
|
||||
_, isTCPConn := dst.(*net.TCPConn)
|
||||
if !isTCPConn {
|
||||
return false
|
||||
}
|
||||
switch src.(type) {
|
||||
case *net.TCPConn, *net.UnixConn, *os.File:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
func (r *readOnlyReader) Upstream() any {
|
||||
return r.Reader
|
||||
}
|
||||
|
||||
func (r *readOnlyReader) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type writeOnlyWriter struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (w *writeOnlyWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
return Copy(w.Writer, r)
|
||||
}
|
||||
|
||||
func (w *writeOnlyWriter) Upstream() any {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *writeOnlyWriter) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func needWrapper(src, dst any) bool {
|
||||
_, srcTCPConn := src.(*net.TCPConn)
|
||||
_, dstTCPConn := dst.(*net.TCPConn)
|
||||
return (srcTCPConn || dstTCPConn) && !(srcTCPConn && dstTCPConn)
|
||||
}
|
||||
|
||||
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
|
||||
|
@ -45,10 +60,13 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
|
|||
src = N.UnwrapReader(src)
|
||||
dst = N.UnwrapWriter(dst)
|
||||
if wt, ok := src.(io.WriterTo); ok {
|
||||
if needWrapper(dst, src) {
|
||||
dst = &writeOnlyWriter{dst}
|
||||
}
|
||||
return wt.WriteTo(dst)
|
||||
}
|
||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||
if needReadFromWrapper(rt, src) {
|
||||
if needWrapper(rt, src) {
|
||||
src = &readOnlyReader{src}
|
||||
}
|
||||
return rt.ReadFrom(src)
|
||||
|
@ -86,7 +104,7 @@ func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.
|
|||
frontHeadroom := N.CalculateFrontHeadroom(dst)
|
||||
rearHeadroom := N.CalculateRearHeadroom(dst)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
var notFirstTime bool
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
|
@ -143,7 +161,7 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
|
|||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
err = src.ReadBuffer(readBuffer)
|
||||
if err != nil {
|
||||
|
@ -235,7 +253,7 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
|||
var destination M.Socksaddr
|
||||
var notFirstTime bool
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
for {
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = src.ReadPacket(readBuffer)
|
||||
|
@ -293,7 +311,7 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
|
|||
for {
|
||||
buffer := buf.NewSize(bufferSize)
|
||||
readBufferRaw := buffer.Slice()
|
||||
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
|
||||
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
|
||||
readBuffer.Resize(frontHeadroom, 0)
|
||||
destination, err = src.ReadPacket(readBuffer)
|
||||
if err != nil {
|
||||
|
|
|
@ -136,13 +136,13 @@ func calculateWriterMTU(writer any) int {
|
|||
}
|
||||
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
|
||||
upstreamMTU := calculateWriterMTU(upstream.Upstream())
|
||||
if mtu == 0 && upstreamMTU < mtu {
|
||||
if mtu == 0 || upstreamMTU > 0 && upstreamMTU < mtu {
|
||||
mtu = upstreamMTU
|
||||
}
|
||||
}
|
||||
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
|
||||
upstreamMTU := calculateWriterMTU(upstream.UpstreamWriter())
|
||||
if mtu == 0 && upstreamMTU < mtu {
|
||||
if mtu == 0 || upstreamMTU > 0 && upstreamMTU < mtu {
|
||||
mtu = upstreamMTU
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
)
|
||||
|
||||
|
@ -25,7 +26,7 @@ const (
|
|||
var CRLF = []byte{'\r', '\n'}
|
||||
|
||||
type ClientConn struct {
|
||||
net.Conn
|
||||
N.ExtendedConn
|
||||
key [KeyLength]byte
|
||||
destination M.Socksaddr
|
||||
headerWritten bool
|
||||
|
@ -33,17 +34,17 @@ type ClientConn struct {
|
|||
|
||||
func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
|
||||
return &ClientConn{
|
||||
Conn: conn,
|
||||
key: key,
|
||||
destination: destination,
|
||||
ExtendedConn: bufio.NewExtendedConn(conn),
|
||||
key: key,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) Write(p []byte) (n int, err error) {
|
||||
if c.headerWritten {
|
||||
return c.Conn.Write(p)
|
||||
return c.ExtendedConn.Write(p)
|
||||
}
|
||||
err = ClientHandshake(c.Conn, c.key, c.destination, p)
|
||||
err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -54,10 +55,9 @@ func (c *ClientConn) Write(p []byte) (n int, err error) {
|
|||
|
||||
func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
if c.headerWritten {
|
||||
defer buffer.Release()
|
||||
return common.Error(c.Conn.Write(buffer.Bytes()))
|
||||
return c.ExtendedConn.WriteBuffer(buffer)
|
||||
}
|
||||
err := ClientHandshakeBuffer(c.Conn, c.key, c.destination, buffer)
|
||||
err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -69,11 +69,11 @@ func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
if !c.headerWritten {
|
||||
return bufio.ReadFrom0(c, r)
|
||||
}
|
||||
return bufio.Copy(c.Conn, r)
|
||||
return bufio.Copy(c.ExtendedConn, r)
|
||||
}
|
||||
|
||||
func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return bufio.Copy(w, c.Conn)
|
||||
return bufio.Copy(w, c.ExtendedConn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) FrontHeadroom() int {
|
||||
|
@ -84,7 +84,7 @@ func (c *ClientConn) FrontHeadroom() int {
|
|||
}
|
||||
|
||||
func (c *ClientConn) Upstream() any {
|
||||
return c.Conn
|
||||
return c.ExtendedConn
|
||||
}
|
||||
|
||||
type ClientPacketConn struct {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue