Fix close conn

This commit is contained in:
世界 2022-05-01 07:53:26 +08:00
parent a5d5d79e29
commit 118c423774
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 179 additions and 51 deletions

View file

@ -147,9 +147,9 @@ func Close(closers ...any) error {
} }
switch c := closer.(type) { switch c := closer.(type) {
case ReaderWithUpstream: case ReaderWithUpstream:
err = Close(c.Upstream()) err = Close(c.UpstreamReader())
case WriterWithUpstream: case WriterWithUpstream:
err = Close(c.Upstream()) err = Close(c.UpstreamWriter())
} }
if err != nil { if err != nil {
retErr = err retErr = err

View file

@ -18,16 +18,16 @@ func Flush(writer io.Writer) error {
} }
} }
if u, ok := writer.(WriterWithUpstream); ok { if u, ok := writer.(WriterWithUpstream); ok {
if u.Replaceable() { if u.WriterReplaceable() {
if writerBack == writer { if writerBack == writer {
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter { } else if setter, hasSetter := u.UpstreamWriter().(UpstreamWriterSetter); hasSetter {
setter.SetWriter(writerBack) setter.SetWriter(writerBack)
writer = u.Upstream() writer = u.UpstreamWriter()
continue continue
} }
} }
writerBack = writer writerBack = writer
writer = u.Upstream() writer = u.UpstreamWriter()
} else { } else {
break break
} }
@ -46,20 +46,20 @@ func FlushVar(writerP *io.Writer) error {
} }
} }
if u, ok := writer.(WriterWithUpstream); ok { if u, ok := writer.(WriterWithUpstream); ok {
if u.Replaceable() { if u.WriterReplaceable() {
if writerBack == writer { if writerBack == writer {
writer = u.Upstream() writer = u.UpstreamWriter()
writerBack = writer writerBack = writer
*writerP = writer *writerP = writer
continue continue
} else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter { } else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter {
setter.SetWriter(u.Upstream()) setter.SetWriter(u.UpstreamWriter())
writer = u.Upstream() writer = u.UpstreamWriter()
continue continue
} }
} }
writerBack = writer writerBack = writer
writer = u.Upstream() writer = u.UpstreamWriter()
} else { } else {
break break
} }
@ -72,11 +72,11 @@ type FlushOnceWriter struct {
flushed bool flushed bool
} }
func (w *FlushOnceWriter) Upstream() io.Writer { func (w *FlushOnceWriter) UpstreamWriter() io.Writer {
return w.Writer return w.Writer
} }
func (w *FlushOnceWriter) Replaceable() bool { func (w *FlushOnceWriter) WriterReplaceable() bool {
return w.flushed return w.flushed
} }

View file

@ -12,11 +12,11 @@ type BufferedWriter struct {
Buffer *buf.Buffer Buffer *buf.Buffer
} }
func (w *BufferedWriter) Upstream() io.Writer { func (w *BufferedWriter) UpstreamWriter() io.Writer {
return w.Writer return w.Writer
} }
func (w *BufferedWriter) Replaceable() bool { func (w *BufferedWriter) WriterReplaceable() bool {
return w.Buffer == nil return w.Buffer == nil
} }
@ -85,11 +85,11 @@ type HeaderWriter struct {
Header *buf.Buffer Header *buf.Buffer
} }
func (w *HeaderWriter) Upstream() io.Writer { func (w *HeaderWriter) UpstreamWriter() io.Writer {
return w.Writer return w.Writer
} }
func (w *HeaderWriter) Replaceable() bool { func (w *HeaderWriter) WriterReplaceable() bool {
return w.Header == nil return w.Header == nil
} }

View file

@ -25,13 +25,13 @@ func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
} }
} }
if u, ok := writer.(common.WriterWithUpstream); ok { if u, ok := writer.(common.WriterWithUpstream); ok {
if u.Replaceable() && writerBack == writer { if u.WriterReplaceable() && writerBack == writer {
writer = u.Upstream() writer = u.UpstreamWriter()
writerBack = writer writerBack = writer
writerVar = &writer writerVar = &writer
continue continue
} }
writer = u.Upstream() writer = u.UpstreamWriter()
writerBack = writer writerBack = writer
} else { } else {
break break

View file

@ -1,6 +1,8 @@
package rw package rw
import "io" import (
"github.com/sagernet/sing/common"
)
type ReadCloser interface { type ReadCloser interface {
CloseRead() error CloseRead() error
@ -10,16 +12,32 @@ type WriteCloser interface {
CloseWrite() error CloseWrite() error
} }
func CloseRead(conn io.Closer) error { func CloseRead(reader any) error {
if closer, ok := conn.(ReadCloser); ok { r := reader
for {
if closer, ok := r.(ReadCloser); ok {
return closer.CloseRead() return closer.CloseRead()
} }
return nil if u, ok := r.(common.ReaderWithUpstream); ok {
r = u.UpstreamReader()
continue
}
break
}
return common.Close(reader)
} }
func CloseWrite(conn io.Closer) error { func CloseWrite(writer any) error {
if closer, ok := conn.(WriteCloser); ok { w := writer
for {
if closer, ok := w.(WriteCloser); ok {
return closer.CloseWrite() return closer.CloseWrite()
} }
return nil if u, ok := w.(common.WriterWithUpstream); ok {
w = u.UpstreamWriter()
continue
}
break
}
return common.Close(writer)
} }

View file

@ -5,17 +5,17 @@ import (
) )
type ReaderWithUpstream interface { type ReaderWithUpstream interface {
Upstream() io.Reader UpstreamReader() io.Reader
Replaceable() bool ReaderReplaceable() bool
} }
type UpstreamReaderSetter interface { type UpstreamReaderSetter interface {
SetUpstream(reader io.Reader) SetReader(reader io.Reader)
} }
type WriterWithUpstream interface { type WriterWithUpstream interface {
Upstream() io.Writer UpstreamWriter() io.Writer
Replaceable() bool WriterReplaceable() bool
} }
type UpstreamWriterSetter interface { type UpstreamWriterSetter interface {

View file

@ -41,15 +41,15 @@ func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce [
} }
} }
func (r *Reader) Upstream() io.Reader { func (r *Reader) UpstreamReader() io.Reader {
return r.upstream return r.upstream
} }
func (r *Reader) Replaceable() bool { func (r *Reader) ReaderReplaceable() bool {
return false return false
} }
func (r *Reader) SetUpstream(reader io.Reader) { func (r *Reader) SetReader(reader io.Reader) {
r.upstream = reader r.upstream = reader
} }
@ -226,11 +226,11 @@ func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buf
} }
} }
func (w *Writer) Upstream() io.Writer { func (w *Writer) UpstreamWriter() io.Writer {
return w.upstream return w.upstream
} }
func (w *Writer) Replaceable() bool { func (w *Writer) WriterReplaceable() bool {
return false return false
} }
@ -299,11 +299,11 @@ type BufferedWriter struct {
index int index int
} }
func (w *BufferedWriter) Upstream() io.Writer { func (w *BufferedWriter) UpstreamWriter() io.Writer {
return w.upstream return w.upstream
} }
func (w *BufferedWriter) Replaceable() bool { func (w *BufferedWriter) WriterReplaceable() bool {
return w.index == 0 return w.index == 0
} }

View file

@ -298,6 +298,28 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
return c.writer.ReadFrom(r) return c.writer.ReadFrom(r)
} }
func (c *clientConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *clientConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *clientConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *clientConn) WriterReplaceable() bool {
return c.writer != nil
}
type clientPacketConn struct { type clientPacketConn struct {
*Method *Method
net.Conn net.Conn
@ -329,3 +351,19 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
} }
return socks.AddressSerializer.ReadAddrPort(buffer) return socks.AddressSerializer.ReadAddrPort(buffer)
} }
func (c *clientPacketConn) UpstreamReader() io.Reader {
return c.Conn
}
func (c *clientPacketConn) ReaderReplaceable() bool {
return false
}
func (c *clientPacketConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *clientPacketConn) WriterReplaceable() bool {
return false
}

View file

@ -168,22 +168,34 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w) return c.reader.WriteTo(w)
} }
func (c *serverConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *serverConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *serverConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
if buffer.Len() < s.keySaltLength { if buffer.Len() < s.keySaltLength {
return E.New("bad packet") return E.New("bad packet")
} }
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength) key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
c := s.constructor(common.Dup(key)) c := s.constructor(common.Dup(key))
/*data := buf.New()
packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
data.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, data, metadata)*/
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil { if err != nil {
return err return err

View file

@ -377,6 +377,28 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
return c.writer.ReadFrom(r) return c.writer.ReadFrom(r)
} }
func (c *clientConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *clientConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *clientConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *clientConn) WriterReplaceable() bool {
return c.writer != nil
}
type clientPacketConn struct { type clientPacketConn struct {
net.Conn net.Conn
method *Method method *Method
@ -582,3 +604,19 @@ func (m *Method) newUDPSession() *udpSession {
} }
return session return session
} }
func (c *clientPacketConn) UpstreamReader() io.Reader {
return c.Conn
}
func (c *clientPacketConn) ReaderReplaceable() bool {
return false
}
func (c *clientPacketConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *clientPacketConn) WriterReplaceable() bool {
return false
}

View file

@ -207,6 +207,28 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w) return c.reader.WriteTo(w)
} }
func (c *serverConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *serverConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *serverConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
var packetHeader []byte var packetHeader []byte
if s.udpCipher != nil { if s.udpCipher != nil {