diff --git a/common/cond.go b/common/cond.go index 6cf44cf..b8aa731 100644 --- a/common/cond.go +++ b/common/cond.go @@ -147,9 +147,9 @@ func Close(closers ...any) error { } switch c := closer.(type) { case ReaderWithUpstream: - err = Close(c.Upstream()) + err = Close(c.UpstreamReader()) case WriterWithUpstream: - err = Close(c.Upstream()) + err = Close(c.UpstreamWriter()) } if err != nil { retErr = err diff --git a/common/flush.go b/common/flush.go index bdbc381..db004ba 100644 --- a/common/flush.go +++ b/common/flush.go @@ -18,16 +18,16 @@ func Flush(writer io.Writer) error { } } if u, ok := writer.(WriterWithUpstream); ok { - if u.Replaceable() { + if u.WriterReplaceable() { if writerBack == writer { - } else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter { + } else if setter, hasSetter := u.UpstreamWriter().(UpstreamWriterSetter); hasSetter { setter.SetWriter(writerBack) - writer = u.Upstream() + writer = u.UpstreamWriter() continue } } writerBack = writer - writer = u.Upstream() + writer = u.UpstreamWriter() } else { break } @@ -46,20 +46,20 @@ func FlushVar(writerP *io.Writer) error { } } if u, ok := writer.(WriterWithUpstream); ok { - if u.Replaceable() { + if u.WriterReplaceable() { if writerBack == writer { - writer = u.Upstream() + writer = u.UpstreamWriter() writerBack = writer *writerP = writer continue } else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter { - setter.SetWriter(u.Upstream()) - writer = u.Upstream() + setter.SetWriter(u.UpstreamWriter()) + writer = u.UpstreamWriter() continue } } writerBack = writer - writer = u.Upstream() + writer = u.UpstreamWriter() } else { break } @@ -72,11 +72,11 @@ type FlushOnceWriter struct { flushed bool } -func (w *FlushOnceWriter) Upstream() io.Writer { +func (w *FlushOnceWriter) UpstreamWriter() io.Writer { return w.Writer } -func (w *FlushOnceWriter) Replaceable() bool { +func (w *FlushOnceWriter) WriterReplaceable() bool { return w.flushed } diff --git a/common/rw/buffer.go b/common/rw/buffer.go index 0b60ff0..ac22ebe 100644 --- a/common/rw/buffer.go +++ b/common/rw/buffer.go @@ -12,11 +12,11 @@ type BufferedWriter struct { Buffer *buf.Buffer } -func (w *BufferedWriter) Upstream() io.Writer { +func (w *BufferedWriter) UpstreamWriter() io.Writer { return w.Writer } -func (w *BufferedWriter) Replaceable() bool { +func (w *BufferedWriter) WriterReplaceable() bool { return w.Buffer == nil } @@ -85,11 +85,11 @@ type HeaderWriter struct { Header *buf.Buffer } -func (w *HeaderWriter) Upstream() io.Writer { +func (w *HeaderWriter) UpstreamWriter() io.Writer { return w.Writer } -func (w *HeaderWriter) Replaceable() bool { +func (w *HeaderWriter) WriterReplaceable() bool { return w.Header == nil } diff --git a/common/rw/copy.go b/common/rw/copy.go index bc0827c..0472c82 100644 --- a/common/rw/copy.go +++ b/common/rw/copy.go @@ -25,13 +25,13 @@ func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) { } } if u, ok := writer.(common.WriterWithUpstream); ok { - if u.Replaceable() && writerBack == writer { - writer = u.Upstream() + if u.WriterReplaceable() && writerBack == writer { + writer = u.UpstreamWriter() writerBack = writer writerVar = &writer continue } - writer = u.Upstream() + writer = u.UpstreamWriter() writerBack = writer } else { break diff --git a/common/rw/duplex.go b/common/rw/duplex.go index 5c809b5..eaaf77d 100644 --- a/common/rw/duplex.go +++ b/common/rw/duplex.go @@ -1,6 +1,8 @@ package rw -import "io" +import ( + "github.com/sagernet/sing/common" +) type ReadCloser interface { CloseRead() error @@ -10,16 +12,32 @@ type WriteCloser interface { CloseWrite() error } -func CloseRead(conn io.Closer) error { - if closer, ok := conn.(ReadCloser); ok { - return closer.CloseRead() +func CloseRead(reader any) error { + r := reader + for { + if closer, ok := r.(ReadCloser); ok { + return closer.CloseRead() + } + if u, ok := r.(common.ReaderWithUpstream); ok { + r = u.UpstreamReader() + continue + } + break } - return nil + return common.Close(reader) } -func CloseWrite(conn io.Closer) error { - if closer, ok := conn.(WriteCloser); ok { - return closer.CloseWrite() +func CloseWrite(writer any) error { + w := writer + for { + if closer, ok := w.(WriteCloser); ok { + return closer.CloseWrite() + } + if u, ok := w.(common.WriterWithUpstream); ok { + w = u.UpstreamWriter() + continue + } + break } - return nil + return common.Close(writer) } diff --git a/common/upstream.go b/common/upstream.go index 7bad06f..e9e2f2e 100644 --- a/common/upstream.go +++ b/common/upstream.go @@ -5,17 +5,17 @@ import ( ) type ReaderWithUpstream interface { - Upstream() io.Reader - Replaceable() bool + UpstreamReader() io.Reader + ReaderReplaceable() bool } type UpstreamReaderSetter interface { - SetUpstream(reader io.Reader) + SetReader(reader io.Reader) } type WriterWithUpstream interface { - Upstream() io.Writer - Replaceable() bool + UpstreamWriter() io.Writer + WriterReplaceable() bool } type UpstreamWriterSetter interface { diff --git a/protocol/shadowsocks/shadowaead/aead.go b/protocol/shadowsocks/shadowaead/aead.go index de3f08a..40127df 100644 --- a/protocol/shadowsocks/shadowaead/aead.go +++ b/protocol/shadowsocks/shadowaead/aead.go @@ -41,15 +41,15 @@ func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce [ } } -func (r *Reader) Upstream() io.Reader { +func (r *Reader) UpstreamReader() io.Reader { return r.upstream } -func (r *Reader) Replaceable() bool { +func (r *Reader) ReaderReplaceable() bool { return false } -func (r *Reader) SetUpstream(reader io.Reader) { +func (r *Reader) SetReader(reader io.Reader) { r.upstream = reader } @@ -226,11 +226,11 @@ func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buf } } -func (w *Writer) Upstream() io.Writer { +func (w *Writer) UpstreamWriter() io.Writer { return w.upstream } -func (w *Writer) Replaceable() bool { +func (w *Writer) WriterReplaceable() bool { return false } @@ -299,11 +299,11 @@ type BufferedWriter struct { index int } -func (w *BufferedWriter) Upstream() io.Writer { +func (w *BufferedWriter) UpstreamWriter() io.Writer { return w.upstream } -func (w *BufferedWriter) Replaceable() bool { +func (w *BufferedWriter) WriterReplaceable() bool { return w.index == 0 } diff --git a/protocol/shadowsocks/shadowaead/protocol.go b/protocol/shadowsocks/shadowaead/protocol.go index 5867417..340409b 100644 --- a/protocol/shadowsocks/shadowaead/protocol.go +++ b/protocol/shadowsocks/shadowaead/protocol.go @@ -298,6 +298,28 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { return c.writer.ReadFrom(r) } +func (c *clientConn) UpstreamReader() io.Reader { + if c.reader == nil { + return c.Conn + } + return c.reader +} + +func (c *clientConn) ReaderReplaceable() bool { + return c.reader != nil +} + +func (c *clientConn) UpstreamWriter() io.Writer { + if c.writer == nil { + return c.Conn + } + return c.writer +} + +func (c *clientConn) WriterReplaceable() bool { + return c.writer != nil +} + type clientPacketConn struct { *Method net.Conn @@ -329,3 +351,19 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { } return socks.AddressSerializer.ReadAddrPort(buffer) } + +func (c *clientPacketConn) UpstreamReader() io.Reader { + return c.Conn +} + +func (c *clientPacketConn) ReaderReplaceable() bool { + return false +} + +func (c *clientPacketConn) UpstreamWriter() io.Writer { + return c.Conn +} + +func (c *clientPacketConn) WriterReplaceable() bool { + return false +} diff --git a/protocol/shadowsocks/shadowaead/service.go b/protocol/shadowsocks/shadowaead/service.go index a6c6eb2..9ac1b7e 100644 --- a/protocol/shadowsocks/shadowaead/service.go +++ b/protocol/shadowsocks/shadowaead/service.go @@ -168,22 +168,34 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { return c.reader.WriteTo(w) } +func (c *serverConn) UpstreamReader() io.Reader { + if c.reader == nil { + return c.Conn + } + return c.reader +} + +func (c *serverConn) ReaderReplaceable() bool { + return c.reader != nil +} + +func (c *serverConn) UpstreamWriter() io.Writer { + if c.writer == nil { + return c.Conn + } + return c.writer +} + +func (c *serverConn) WriterReplaceable() bool { + return c.writer != nil +} + func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { if buffer.Len() < s.keySaltLength { return E.New("bad packet") } key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength) c := s.constructor(common.Dup(key)) - /*data := buf.New() - packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) - if err != nil { - return err - } - data.Truncate(len(packet)) - metadata.Protocol = "shadowsocks" - return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter { - return &serverPacketWriter{s, conn, metadata.Source} - }, data, metadata)*/ packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) if err != nil { return err diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index 56b0928..a514e42 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -377,6 +377,28 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { return c.writer.ReadFrom(r) } +func (c *clientConn) UpstreamReader() io.Reader { + if c.reader == nil { + return c.Conn + } + return c.reader +} + +func (c *clientConn) ReaderReplaceable() bool { + return c.reader != nil +} + +func (c *clientConn) UpstreamWriter() io.Writer { + if c.writer == nil { + return c.Conn + } + return c.writer +} + +func (c *clientConn) WriterReplaceable() bool { + return c.writer != nil +} + type clientPacketConn struct { net.Conn method *Method @@ -582,3 +604,19 @@ func (m *Method) newUDPSession() *udpSession { } return session } + +func (c *clientPacketConn) UpstreamReader() io.Reader { + return c.Conn +} + +func (c *clientPacketConn) ReaderReplaceable() bool { + return false +} + +func (c *clientPacketConn) UpstreamWriter() io.Writer { + return c.Conn +} + +func (c *clientPacketConn) WriterReplaceable() bool { + return false +} diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index 30e48a9..2e7dac3 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -207,6 +207,28 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { return c.reader.WriteTo(w) } +func (c *serverConn) UpstreamReader() io.Reader { + if c.reader == nil { + return c.Conn + } + return c.reader +} + +func (c *serverConn) ReaderReplaceable() bool { + return c.reader != nil +} + +func (c *serverConn) UpstreamWriter() io.Writer { + if c.writer == nil { + return c.Conn + } + return c.writer +} + +func (c *serverConn) WriterReplaceable() bool { + return c.writer != nil +} + func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { var packetHeader []byte if s.udpCipher != nil {