Fix calculate mtu

This commit is contained in:
世界 2022-09-30 21:59:34 +08:00
parent cb9b17d6a4
commit 3483762200
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 50 additions and 32 deletions

View file

@ -4,8 +4,6 @@ import (
"context"
"io"
"net"
"os"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
@ -23,17 +21,34 @@ func (r *readOnlyReader) WriteTo(w io.Writer) (n int64, err error) {
return Copy(w, r.Reader)
}
func needReadFromWrapper(dst io.ReaderFrom, src io.Reader) bool {
_, isTCPConn := dst.(*net.TCPConn)
if !isTCPConn {
return false
}
switch src.(type) {
case *net.TCPConn, *net.UnixConn, *os.File:
return false
default:
return true
}
func (r *readOnlyReader) Upstream() any {
return r.Reader
}
func (r *readOnlyReader) ReaderReplaceable() bool {
return true
}
type writeOnlyWriter struct {
io.Writer
}
func (w *writeOnlyWriter) ReadFrom(r io.Reader) (n int64, err error) {
return Copy(w.Writer, r)
}
func (w *writeOnlyWriter) Upstream() any {
return w.Writer
}
func (w *writeOnlyWriter) WriterReplaceable() bool {
return true
}
func needWrapper(src, dst any) bool {
_, srcTCPConn := src.(*net.TCPConn)
_, dstTCPConn := dst.(*net.TCPConn)
return (srcTCPConn || dstTCPConn) && !(srcTCPConn && dstTCPConn)
}
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
@ -45,10 +60,13 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
src = N.UnwrapReader(src)
dst = N.UnwrapWriter(dst)
if wt, ok := src.(io.WriterTo); ok {
if needWrapper(dst, src) {
dst = &writeOnlyWriter{dst}
}
return wt.WriteTo(dst)
}
if rt, ok := dst.(io.ReaderFrom); ok {
if needReadFromWrapper(rt, src) {
if needWrapper(rt, src) {
src = &readOnlyReader{src}
}
return rt.ReadFrom(src)
@ -86,7 +104,7 @@ func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
var notFirstTime bool
for {
readBuffer.Resize(frontHeadroom, 0)
@ -143,7 +161,7 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = src.ReadBuffer(readBuffer)
if err != nil {
@ -235,7 +253,7 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
var destination M.Socksaddr
var notFirstTime bool
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
for {
readBuffer.Resize(frontHeadroom, 0)
destination, err = src.ReadPacket(readBuffer)
@ -293,7 +311,7 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
destination, err = src.ReadPacket(readBuffer)
if err != nil {

View file

@ -136,13 +136,13 @@ func calculateWriterMTU(writer any) int {
}
if upstream, hasUpstream := writer.(common.WithUpstream); hasUpstream {
upstreamMTU := calculateWriterMTU(upstream.Upstream())
if mtu == 0 && upstreamMTU < mtu {
if mtu == 0 || upstreamMTU > 0 && upstreamMTU < mtu {
mtu = upstreamMTU
}
}
if upstream, hasUpstream := writer.(WithUpstreamWriter); hasUpstream {
upstreamMTU := calculateWriterMTU(upstream.UpstreamWriter())
if mtu == 0 && upstreamMTU < mtu {
if mtu == 0 || upstreamMTU > 0 && upstreamMTU < mtu {
mtu = upstreamMTU
}
}

View file

@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
@ -25,7 +26,7 @@ const (
var CRLF = []byte{'\r', '\n'}
type ClientConn struct {
net.Conn
N.ExtendedConn
key [KeyLength]byte
destination M.Socksaddr
headerWritten bool
@ -33,17 +34,17 @@ type ClientConn struct {
func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
return &ClientConn{
Conn: conn,
key: key,
destination: destination,
ExtendedConn: bufio.NewExtendedConn(conn),
key: key,
destination: destination,
}
}
func (c *ClientConn) Write(p []byte) (n int, err error) {
if c.headerWritten {
return c.Conn.Write(p)
return c.ExtendedConn.Write(p)
}
err = ClientHandshake(c.Conn, c.key, c.destination, p)
err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
if err != nil {
return
}
@ -54,10 +55,9 @@ func (c *ClientConn) Write(p []byte) (n int, err error) {
func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
if c.headerWritten {
defer buffer.Release()
return common.Error(c.Conn.Write(buffer.Bytes()))
return c.ExtendedConn.WriteBuffer(buffer)
}
err := ClientHandshakeBuffer(c.Conn, c.key, c.destination, buffer)
err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
if err != nil {
return err
}
@ -69,11 +69,11 @@ func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) {
if !c.headerWritten {
return bufio.ReadFrom0(c, r)
}
return bufio.Copy(c.Conn, r)
return bufio.Copy(c.ExtendedConn, r)
}
func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
return bufio.Copy(w, c.Conn)
return bufio.Copy(w, c.ExtendedConn)
}
func (c *ClientConn) FrontHeadroom() int {
@ -84,7 +84,7 @@ func (c *ClientConn) FrontHeadroom() int {
}
func (c *ClientConn) Upstream() any {
return c.Conn
return c.ExtendedConn
}
type ClientPacketConn struct {