Fix close conn

This commit is contained in:
世界 2022-05-01 07:53:26 +08:00
parent a5d5d79e29
commit 118c423774
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 179 additions and 51 deletions

View file

@ -147,9 +147,9 @@ func Close(closers ...any) error {
}
switch c := closer.(type) {
case ReaderWithUpstream:
err = Close(c.Upstream())
err = Close(c.UpstreamReader())
case WriterWithUpstream:
err = Close(c.Upstream())
err = Close(c.UpstreamWriter())
}
if err != nil {
retErr = err

View file

@ -18,16 +18,16 @@ func Flush(writer io.Writer) error {
}
}
if u, ok := writer.(WriterWithUpstream); ok {
if u.Replaceable() {
if u.WriterReplaceable() {
if writerBack == writer {
} else if setter, hasSetter := u.Upstream().(UpstreamWriterSetter); hasSetter {
} else if setter, hasSetter := u.UpstreamWriter().(UpstreamWriterSetter); hasSetter {
setter.SetWriter(writerBack)
writer = u.Upstream()
writer = u.UpstreamWriter()
continue
}
}
writerBack = writer
writer = u.Upstream()
writer = u.UpstreamWriter()
} else {
break
}
@ -46,20 +46,20 @@ func FlushVar(writerP *io.Writer) error {
}
}
if u, ok := writer.(WriterWithUpstream); ok {
if u.Replaceable() {
if u.WriterReplaceable() {
if writerBack == writer {
writer = u.Upstream()
writer = u.UpstreamWriter()
writerBack = writer
*writerP = writer
continue
} else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter {
setter.SetWriter(u.Upstream())
writer = u.Upstream()
setter.SetWriter(u.UpstreamWriter())
writer = u.UpstreamWriter()
continue
}
}
writerBack = writer
writer = u.Upstream()
writer = u.UpstreamWriter()
} else {
break
}
@ -72,11 +72,11 @@ type FlushOnceWriter struct {
flushed bool
}
func (w *FlushOnceWriter) Upstream() io.Writer {
func (w *FlushOnceWriter) UpstreamWriter() io.Writer {
return w.Writer
}
func (w *FlushOnceWriter) Replaceable() bool {
func (w *FlushOnceWriter) WriterReplaceable() bool {
return w.flushed
}

View file

@ -12,11 +12,11 @@ type BufferedWriter struct {
Buffer *buf.Buffer
}
func (w *BufferedWriter) Upstream() io.Writer {
func (w *BufferedWriter) UpstreamWriter() io.Writer {
return w.Writer
}
func (w *BufferedWriter) Replaceable() bool {
func (w *BufferedWriter) WriterReplaceable() bool {
return w.Buffer == nil
}
@ -85,11 +85,11 @@ type HeaderWriter struct {
Header *buf.Buffer
}
func (w *HeaderWriter) Upstream() io.Writer {
func (w *HeaderWriter) UpstreamWriter() io.Writer {
return w.Writer
}
func (w *HeaderWriter) Replaceable() bool {
func (w *HeaderWriter) WriterReplaceable() bool {
return w.Header == nil
}

View file

@ -25,13 +25,13 @@ func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
}
}
if u, ok := writer.(common.WriterWithUpstream); ok {
if u.Replaceable() && writerBack == writer {
writer = u.Upstream()
if u.WriterReplaceable() && writerBack == writer {
writer = u.UpstreamWriter()
writerBack = writer
writerVar = &writer
continue
}
writer = u.Upstream()
writer = u.UpstreamWriter()
writerBack = writer
} else {
break

View file

@ -1,6 +1,8 @@
package rw
import "io"
import (
"github.com/sagernet/sing/common"
)
type ReadCloser interface {
CloseRead() error
@ -10,16 +12,32 @@ type WriteCloser interface {
CloseWrite() error
}
func CloseRead(conn io.Closer) error {
if closer, ok := conn.(ReadCloser); ok {
return closer.CloseRead()
func CloseRead(reader any) error {
r := reader
for {
if closer, ok := r.(ReadCloser); ok {
return closer.CloseRead()
}
if u, ok := r.(common.ReaderWithUpstream); ok {
r = u.UpstreamReader()
continue
}
break
}
return nil
return common.Close(reader)
}
func CloseWrite(conn io.Closer) error {
if closer, ok := conn.(WriteCloser); ok {
return closer.CloseWrite()
func CloseWrite(writer any) error {
w := writer
for {
if closer, ok := w.(WriteCloser); ok {
return closer.CloseWrite()
}
if u, ok := w.(common.WriterWithUpstream); ok {
w = u.UpstreamWriter()
continue
}
break
}
return nil
return common.Close(writer)
}

View file

@ -5,17 +5,17 @@ import (
)
type ReaderWithUpstream interface {
Upstream() io.Reader
Replaceable() bool
UpstreamReader() io.Reader
ReaderReplaceable() bool
}
type UpstreamReaderSetter interface {
SetUpstream(reader io.Reader)
SetReader(reader io.Reader)
}
type WriterWithUpstream interface {
Upstream() io.Writer
Replaceable() bool
UpstreamWriter() io.Writer
WriterReplaceable() bool
}
type UpstreamWriterSetter interface {

View file

@ -41,15 +41,15 @@ func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce [
}
}
func (r *Reader) Upstream() io.Reader {
func (r *Reader) UpstreamReader() io.Reader {
return r.upstream
}
func (r *Reader) Replaceable() bool {
func (r *Reader) ReaderReplaceable() bool {
return false
}
func (r *Reader) SetUpstream(reader io.Reader) {
func (r *Reader) SetReader(reader io.Reader) {
r.upstream = reader
}
@ -226,11 +226,11 @@ func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buf
}
}
func (w *Writer) Upstream() io.Writer {
func (w *Writer) UpstreamWriter() io.Writer {
return w.upstream
}
func (w *Writer) Replaceable() bool {
func (w *Writer) WriterReplaceable() bool {
return false
}
@ -299,11 +299,11 @@ type BufferedWriter struct {
index int
}
func (w *BufferedWriter) Upstream() io.Writer {
func (w *BufferedWriter) UpstreamWriter() io.Writer {
return w.upstream
}
func (w *BufferedWriter) Replaceable() bool {
func (w *BufferedWriter) WriterReplaceable() bool {
return w.index == 0
}

View file

@ -298,6 +298,28 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
return c.writer.ReadFrom(r)
}
func (c *clientConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *clientConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *clientConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *clientConn) WriterReplaceable() bool {
return c.writer != nil
}
type clientPacketConn struct {
*Method
net.Conn
@ -329,3 +351,19 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
}
return socks.AddressSerializer.ReadAddrPort(buffer)
}
func (c *clientPacketConn) UpstreamReader() io.Reader {
return c.Conn
}
func (c *clientPacketConn) ReaderReplaceable() bool {
return false
}
func (c *clientPacketConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *clientPacketConn) WriterReplaceable() bool {
return false
}

View file

@ -168,22 +168,34 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (c *serverConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *serverConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *serverConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
if buffer.Len() < s.keySaltLength {
return E.New("bad packet")
}
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
c := s.constructor(common.Dup(key))
/*data := buf.New()
packet, err := c.Open(data.Index(0), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
}
data.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, data, metadata)*/
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err

View file

@ -377,6 +377,28 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
return c.writer.ReadFrom(r)
}
func (c *clientConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *clientConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *clientConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *clientConn) WriterReplaceable() bool {
return c.writer != nil
}
type clientPacketConn struct {
net.Conn
method *Method
@ -582,3 +604,19 @@ func (m *Method) newUDPSession() *udpSession {
}
return session
}
func (c *clientPacketConn) UpstreamReader() io.Reader {
return c.Conn
}
func (c *clientPacketConn) ReaderReplaceable() bool {
return false
}
func (c *clientPacketConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *clientPacketConn) WriterReplaceable() bool {
return false
}

View file

@ -207,6 +207,28 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (c *serverConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *serverConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *serverConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
}
func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
var packetHeader []byte
if s.udpCipher != nil {