Fix ss-server close

This commit is contained in:
世界 2022-05-15 23:52:19 +08:00
parent 2be8304e36
commit cd0e6406c3
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
14 changed files with 56 additions and 366 deletions

View file

@ -145,10 +145,8 @@ func Close(closers ...any) error {
err = c.Close()
}
switch c := closer.(type) {
case ReaderWithUpstream:
err = Close(c.UpstreamReader())
case WriterWithUpstream:
err = Close(c.UpstreamWriter())
case WithUpstream:
err = Close(c.Upstream())
}
if err != nil {
retErr = err

View file

@ -1,95 +0,0 @@
package common
import (
"io"
)
type Flusher interface {
Flush() error
}
func Flush(writer io.Writer) error {
writerBack := writer
for {
if f, ok := writer.(Flusher); ok {
err := f.Flush()
if err != nil {
return err
}
}
if u, ok := writer.(WriterWithUpstream); ok {
if u.WriterReplaceable() {
if writerBack == writer {
} else if setter, hasSetter := u.UpstreamWriter().(UpstreamWriterSetter); hasSetter {
setter.SetWriter(writerBack)
writer = u.UpstreamWriter()
continue
}
}
writerBack = writer
writer = u.UpstreamWriter()
} else {
break
}
}
return nil
}
func FlushVar(writerP *io.Writer) error {
writer := *writerP
writerBack := writer
for {
if f, ok := writer.(Flusher); ok {
err := f.Flush()
if err != nil {
return err
}
}
if u, ok := writer.(WriterWithUpstream); ok {
if u.WriterReplaceable() {
if writerBack == writer {
writer = u.UpstreamWriter()
writerBack = writer
*writerP = writer
continue
} else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter {
setter.SetWriter(u.UpstreamWriter())
writer = u.UpstreamWriter()
continue
}
}
writerBack = writer
writer = u.UpstreamWriter()
} else {
break
}
}
return nil
}
type FlushOnceWriter struct {
io.Writer
flushed bool
}
func (w *FlushOnceWriter) UpstreamWriter() io.Writer {
return w.Writer
}
func (w *FlushOnceWriter) WriterReplaceable() bool {
return w.flushed
}
func (w *FlushOnceWriter) Write(p []byte) (n int, err error) {
if w.flushed {
return w.Writer.Write(p)
}
n, err = w.Writer.Write(p)
if n > 0 {
err = FlushVar(&w.Writer)
}
if err == nil {
w.flushed = true
}
return
}

View file

@ -10,7 +10,6 @@ import (
)
type CachedReader interface {
common.ReaderWithUpstream
ReadCached() *buf.Buffer
}
@ -62,22 +61,10 @@ func (c *BufferedConn) ReadFrom(r io.Reader) (n int64, err error) {
return Copy(c.Conn, r)
}
func (c *BufferedConn) UpstreamReader() io.Reader {
func (c *BufferedConn) Upstream() any {
return c.Conn
}
func (c *BufferedConn) ReaderReplaceable() bool {
return c.Buffer == nil
}
func (c *BufferedConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *BufferedConn) WriterReplaceable() bool {
return true
}
type BufferedReader struct {
Reader io.Reader
Buffer *buf.Buffer
@ -114,27 +101,19 @@ func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
return
}
func (w *BufferedReader) UpstreamReader() io.Reader {
func (w *BufferedReader) Upstream() any {
return w.Reader
}
func (w *BufferedReader) ReaderReplaceable() bool {
return w.Buffer == nil
}
type BufferedWriter struct {
Writer io.Writer
Buffer *buf.Buffer
}
func (w *BufferedWriter) UpstreamWriter() io.Writer {
func (w *BufferedWriter) Upstream() any {
return w.Writer
}
func (w *BufferedWriter) WriterReplaceable() bool {
return w.Buffer == nil
}
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
if w.Buffer == nil {
return w.Writer.Write(p)
@ -200,14 +179,10 @@ type HeaderWriter struct {
Header *buf.Buffer
}
func (w *HeaderWriter) UpstreamWriter() io.Writer {
func (w *HeaderWriter) Upstream() any {
return w.Writer
}
func (w *HeaderWriter) WriterReplaceable() bool {
return w.Header == nil
}
func (w *HeaderWriter) Write(p []byte) (n int, err error) {
if w.Header == nil {
return w.Writer.Write(p)

View file

@ -4,7 +4,6 @@ import (
"context"
"io"
"net"
"os"
"runtime"
"github.com/sagernet/sing/common"
@ -12,35 +11,6 @@ import (
"github.com/sagernet/sing/common/task"
)
func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
writer := *writerVar
writerBack := writer
for {
if w, ok := writer.(io.ReaderFrom); ok {
return w.ReadFrom(reader)
}
if f, ok := writer.(common.Flusher); ok {
err := f.Flush()
if err != nil {
return 0, err
}
}
if u, ok := writer.(common.WriterWithUpstream); ok {
if u.WriterReplaceable() && writerBack == writer {
writer = u.UpstreamWriter()
writerBack = writer
writerVar = &writer
continue
}
writer = u.UpstreamWriter()
writerBack = writer
} else {
break
}
}
return 0, os.ErrInvalid
}
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
if pc, inPc := conn.(net.PacketConn); inPc {
if destPc, outPc := dest.(net.PacketConn); outPc {

View file

@ -13,31 +13,15 @@ type WriteCloser interface {
}
func CloseRead(reader any) error {
r := reader
for {
if closer, ok := r.(ReadCloser); ok {
return closer.CloseRead()
}
if u, ok := r.(common.ReaderWithUpstream); ok {
r = u.UpstreamReader()
continue
}
break
if c, ok := common.Cast[ReadCloser](reader); ok {
return c.CloseRead()
}
return common.Close(reader)
}
func CloseWrite(writer any) error {
w := writer
for {
if closer, ok := w.(WriteCloser); ok {
return closer.CloseWrite()
}
if u, ok := w.(common.WriterWithUpstream); ok {
w = u.UpstreamWriter()
continue
}
break
if c, ok := common.Cast[WriteCloser](writer); ok {
return c.CloseWrite()
}
return common.Close(writer)
}

View file

@ -1,23 +1,16 @@
package common
import (
"io"
)
type ReaderWithUpstream interface {
UpstreamReader() io.Reader
ReaderReplaceable() bool
type WithUpstream interface {
Upstream() any
}
type UpstreamReaderSetter interface {
SetReader(reader io.Reader)
}
type WriterWithUpstream interface {
UpstreamWriter() io.Writer
WriterReplaceable() bool
}
type UpstreamWriterSetter interface {
SetWriter(writer io.Writer)
func Cast[T any](obj any) (T, bool) {
if c, ok := obj.(T); ok {
return c, true
}
if u, ok := obj.(WithUpstream); ok {
return Cast[T](u.Upstream())
}
var defaultValue T
return defaultValue, false
}

View file

@ -5,6 +5,7 @@ import (
"fmt"
"net"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@ -33,8 +34,8 @@ type ServerConnError struct {
}
func (e *ServerConnError) Close() error {
if tcpConn, ok := e.Conn.(*net.TCPConn); ok {
tcpConn.SetLinger(0)
if conn, ok := common.Cast[*net.TCPConn](e.Conn); ok {
conn.SetLinger(0)
}
return e.Conn.Close()
}

View file

@ -53,18 +53,10 @@ func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce [
}
}
func (r *Reader) UpstreamReader() io.Reader {
func (r *Reader) Upstream() any {
return r.upstream
}
func (r *Reader) ReaderReplaceable() bool {
return false
}
func (r *Reader) SetReader(reader io.Reader) {
r.upstream = reader
}
func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
if r.cached > 0 {
writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached])
@ -295,18 +287,10 @@ func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buf
}
}
func (w *Writer) UpstreamWriter() io.Writer {
func (w *Writer) Upstream() any {
return w.upstream
}
func (w *Writer) WriterReplaceable() bool {
return false
}
func (w *Writer) SetWriter(writer io.Writer) {
w.upstream = writer
}
func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
for {
offset := Overhead + PacketLengthBufferSize

View file

@ -279,26 +279,8 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
return c.writer.ReadFrom(r)
}
func (c *clientConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *clientConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *clientConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *clientConn) WriterReplaceable() bool {
return c.writer != nil
func (c *clientConn) Upstream() any {
return c.Conn
}
type clientPacketConn struct {
@ -375,18 +357,6 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return len(p), nil
}
func (c *clientPacketConn) UpstreamReader() io.Reader {
func (c *clientPacketConn) Upstream() any {
return c.Conn
}
func (c *clientPacketConn) ReaderReplaceable() bool {
return false
}
func (c *clientPacketConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *clientPacketConn) WriterReplaceable() bool {
return false
}

View file

@ -186,26 +186,8 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (c *serverConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *serverConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *serverConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
func (c *serverConn) Upstream() any {
return c.Conn
}
func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {

View file

@ -394,26 +394,8 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
return c.writer.ReadFrom(r)
}
func (c *clientConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *clientConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *clientConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *clientConn) WriterReplaceable() bool {
return c.writer != nil
func (c *clientConn) Upstream() any {
return c.Conn
}
type clientPacketConn struct {
@ -732,18 +714,6 @@ func (m *Method) newUDPSession() *udpSession {
return session
}
func (c *clientPacketConn) UpstreamReader() io.Reader {
func (c *clientPacketConn) Upstream() any {
return c.Conn
}
func (c *clientPacketConn) ReaderReplaceable() bool {
return false
}
func (c *clientPacketConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *clientPacketConn) WriterReplaceable() bool {
return false
}

View file

@ -279,26 +279,8 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
return c.reader.WriteTo(w)
}
func (c *serverConn) UpstreamReader() io.Reader {
if c.reader == nil {
return c.Conn
}
return c.reader
}
func (c *serverConn) ReaderReplaceable() bool {
return c.reader != nil
}
func (c *serverConn) UpstreamWriter() io.Writer {
if c.writer == nil {
return c.Conn
}
return c.writer
}
func (c *serverConn) WriterReplaceable() bool {
return c.writer != nil
func (c *serverConn) Upstream() any {
return c.Conn
}
func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {

View file

@ -303,22 +303,10 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
return c.Conn.Write(p)
}
func (c *clientConn) UpstreamReader() io.Reader {
func (c *clientConn) Upstream() any {
return c.Conn
}
func (c *clientConn) ReaderReplaceable() bool {
return false
}
func (c *clientConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *clientConn) WriterReplaceable() bool {
return false
}
type clientPacketConn struct {
*Method
net.Conn
@ -400,18 +388,6 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return len(p), nil
}
func (c *clientPacketConn) UpstreamReader() io.Reader {
func (c *clientPacketConn) Upstream() any {
return c.Conn
}
func (c *clientPacketConn) ReaderReplaceable() bool {
return false
}
func (c *clientPacketConn) UpstreamWriter() io.Writer {
return c.Conn
}
func (c *clientPacketConn) WriterReplaceable() bool {
return false
}

View file

@ -25,23 +25,6 @@ type Listener struct {
*net.TCPListener
}
type Error struct {
Conn net.Conn
Cause error
}
func (e *Error) Error() string {
return e.Cause.Error()
}
func (e *Error) Unwrap() error {
return e.Cause
}
func (e *Error) Close() error {
return common.Close(e.Conn)
}
func NewTCPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
listener := &Listener{
bind: listen,
@ -112,9 +95,26 @@ func (l *Listener) loop() {
metadata.Protocol = "tcp"
hErr := l.handler.NewConnection(context.Background(), tcpConn, metadata)
if hErr != nil {
l.handler.HandleError(&Error{Conn: tcpConn, Cause: hErr})
l.handler.HandleError(hErr)
}
debug.Free()
}()
}
}
type Error struct {
Conn net.Conn
Cause error
}
func (e *Error) Error() string {
return e.Cause.Error()
}
func (e *Error) Unwrap() error {
return e.Cause
}
func (e *Error) Close() error {
return common.Close(e.Conn)
}