mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Fix ss-server close
This commit is contained in:
parent
2be8304e36
commit
cd0e6406c3
14 changed files with 56 additions and 366 deletions
|
@ -145,10 +145,8 @@ func Close(closers ...any) error {
|
|||
err = c.Close()
|
||||
}
|
||||
switch c := closer.(type) {
|
||||
case ReaderWithUpstream:
|
||||
err = Close(c.UpstreamReader())
|
||||
case WriterWithUpstream:
|
||||
err = Close(c.UpstreamWriter())
|
||||
case WithUpstream:
|
||||
err = Close(c.Upstream())
|
||||
}
|
||||
if err != nil {
|
||||
retErr = err
|
||||
|
|
|
@ -1,95 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type Flusher interface {
|
||||
Flush() error
|
||||
}
|
||||
|
||||
func Flush(writer io.Writer) error {
|
||||
writerBack := writer
|
||||
for {
|
||||
if f, ok := writer.(Flusher); ok {
|
||||
err := f.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if u, ok := writer.(WriterWithUpstream); ok {
|
||||
if u.WriterReplaceable() {
|
||||
if writerBack == writer {
|
||||
} else if setter, hasSetter := u.UpstreamWriter().(UpstreamWriterSetter); hasSetter {
|
||||
setter.SetWriter(writerBack)
|
||||
writer = u.UpstreamWriter()
|
||||
continue
|
||||
}
|
||||
}
|
||||
writerBack = writer
|
||||
writer = u.UpstreamWriter()
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func FlushVar(writerP *io.Writer) error {
|
||||
writer := *writerP
|
||||
writerBack := writer
|
||||
for {
|
||||
if f, ok := writer.(Flusher); ok {
|
||||
err := f.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if u, ok := writer.(WriterWithUpstream); ok {
|
||||
if u.WriterReplaceable() {
|
||||
if writerBack == writer {
|
||||
writer = u.UpstreamWriter()
|
||||
writerBack = writer
|
||||
*writerP = writer
|
||||
continue
|
||||
} else if setter, hasSetter := writerBack.(UpstreamWriterSetter); hasSetter {
|
||||
setter.SetWriter(u.UpstreamWriter())
|
||||
writer = u.UpstreamWriter()
|
||||
continue
|
||||
}
|
||||
}
|
||||
writerBack = writer
|
||||
writer = u.UpstreamWriter()
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type FlushOnceWriter struct {
|
||||
io.Writer
|
||||
flushed bool
|
||||
}
|
||||
|
||||
func (w *FlushOnceWriter) UpstreamWriter() io.Writer {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *FlushOnceWriter) WriterReplaceable() bool {
|
||||
return w.flushed
|
||||
}
|
||||
|
||||
func (w *FlushOnceWriter) Write(p []byte) (n int, err error) {
|
||||
if w.flushed {
|
||||
return w.Writer.Write(p)
|
||||
}
|
||||
n, err = w.Writer.Write(p)
|
||||
if n > 0 {
|
||||
err = FlushVar(&w.Writer)
|
||||
}
|
||||
if err == nil {
|
||||
w.flushed = true
|
||||
}
|
||||
return
|
||||
}
|
|
@ -10,7 +10,6 @@ import (
|
|||
)
|
||||
|
||||
type CachedReader interface {
|
||||
common.ReaderWithUpstream
|
||||
ReadCached() *buf.Buffer
|
||||
}
|
||||
|
||||
|
@ -62,22 +61,10 @@ func (c *BufferedConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
return Copy(c.Conn, r)
|
||||
}
|
||||
|
||||
func (c *BufferedConn) UpstreamReader() io.Reader {
|
||||
func (c *BufferedConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *BufferedConn) ReaderReplaceable() bool {
|
||||
return c.Buffer == nil
|
||||
}
|
||||
|
||||
func (c *BufferedConn) UpstreamWriter() io.Writer {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *BufferedConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type BufferedReader struct {
|
||||
Reader io.Reader
|
||||
Buffer *buf.Buffer
|
||||
|
@ -114,27 +101,19 @@ func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (w *BufferedReader) UpstreamReader() io.Reader {
|
||||
func (w *BufferedReader) Upstream() any {
|
||||
return w.Reader
|
||||
}
|
||||
|
||||
func (w *BufferedReader) ReaderReplaceable() bool {
|
||||
return w.Buffer == nil
|
||||
}
|
||||
|
||||
type BufferedWriter struct {
|
||||
Writer io.Writer
|
||||
Buffer *buf.Buffer
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) UpstreamWriter() io.Writer {
|
||||
func (w *BufferedWriter) Upstream() any {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) WriterReplaceable() bool {
|
||||
return w.Buffer == nil
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) Write(p []byte) (n int, err error) {
|
||||
if w.Buffer == nil {
|
||||
return w.Writer.Write(p)
|
||||
|
@ -200,14 +179,10 @@ type HeaderWriter struct {
|
|||
Header *buf.Buffer
|
||||
}
|
||||
|
||||
func (w *HeaderWriter) UpstreamWriter() io.Writer {
|
||||
func (w *HeaderWriter) Upstream() any {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
func (w *HeaderWriter) WriterReplaceable() bool {
|
||||
return w.Header == nil
|
||||
}
|
||||
|
||||
func (w *HeaderWriter) Write(p []byte) (n int, err error) {
|
||||
if w.Header == nil {
|
||||
return w.Writer.Write(p)
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -12,35 +11,6 @@ import (
|
|||
"github.com/sagernet/sing/common/task"
|
||||
)
|
||||
|
||||
func ReadFromVar(writerVar *io.Writer, reader io.Reader) (int64, error) {
|
||||
writer := *writerVar
|
||||
writerBack := writer
|
||||
for {
|
||||
if w, ok := writer.(io.ReaderFrom); ok {
|
||||
return w.ReadFrom(reader)
|
||||
}
|
||||
if f, ok := writer.(common.Flusher); ok {
|
||||
err := f.Flush()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if u, ok := writer.(common.WriterWithUpstream); ok {
|
||||
if u.WriterReplaceable() && writerBack == writer {
|
||||
writer = u.UpstreamWriter()
|
||||
writerBack = writer
|
||||
writerVar = &writer
|
||||
continue
|
||||
}
|
||||
writer = u.UpstreamWriter()
|
||||
writerBack = writer
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||
if pc, inPc := conn.(net.PacketConn); inPc {
|
||||
if destPc, outPc := dest.(net.PacketConn); outPc {
|
||||
|
|
|
@ -13,31 +13,15 @@ type WriteCloser interface {
|
|||
}
|
||||
|
||||
func CloseRead(reader any) error {
|
||||
r := reader
|
||||
for {
|
||||
if closer, ok := r.(ReadCloser); ok {
|
||||
return closer.CloseRead()
|
||||
}
|
||||
if u, ok := r.(common.ReaderWithUpstream); ok {
|
||||
r = u.UpstreamReader()
|
||||
continue
|
||||
}
|
||||
break
|
||||
if c, ok := common.Cast[ReadCloser](reader); ok {
|
||||
return c.CloseRead()
|
||||
}
|
||||
return common.Close(reader)
|
||||
}
|
||||
|
||||
func CloseWrite(writer any) error {
|
||||
w := writer
|
||||
for {
|
||||
if closer, ok := w.(WriteCloser); ok {
|
||||
return closer.CloseWrite()
|
||||
}
|
||||
if u, ok := w.(common.WriterWithUpstream); ok {
|
||||
w = u.UpstreamWriter()
|
||||
continue
|
||||
}
|
||||
break
|
||||
if c, ok := common.Cast[WriteCloser](writer); ok {
|
||||
return c.CloseWrite()
|
||||
}
|
||||
return common.Close(writer)
|
||||
}
|
||||
|
|
|
@ -1,23 +1,16 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type ReaderWithUpstream interface {
|
||||
UpstreamReader() io.Reader
|
||||
ReaderReplaceable() bool
|
||||
type WithUpstream interface {
|
||||
Upstream() any
|
||||
}
|
||||
|
||||
type UpstreamReaderSetter interface {
|
||||
SetReader(reader io.Reader)
|
||||
}
|
||||
|
||||
type WriterWithUpstream interface {
|
||||
UpstreamWriter() io.Writer
|
||||
WriterReplaceable() bool
|
||||
}
|
||||
|
||||
type UpstreamWriterSetter interface {
|
||||
SetWriter(writer io.Writer)
|
||||
func Cast[T any](obj any) (T, bool) {
|
||||
if c, ok := obj.(T); ok {
|
||||
return c, true
|
||||
}
|
||||
if u, ok := obj.(WithUpstream); ok {
|
||||
return Cast[T](u.Upstream())
|
||||
}
|
||||
var defaultValue T
|
||||
return defaultValue, false
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
@ -33,8 +34,8 @@ type ServerConnError struct {
|
|||
}
|
||||
|
||||
func (e *ServerConnError) Close() error {
|
||||
if tcpConn, ok := e.Conn.(*net.TCPConn); ok {
|
||||
tcpConn.SetLinger(0)
|
||||
if conn, ok := common.Cast[*net.TCPConn](e.Conn); ok {
|
||||
conn.SetLinger(0)
|
||||
}
|
||||
return e.Conn.Close()
|
||||
}
|
||||
|
|
|
@ -53,18 +53,10 @@ func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce [
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Reader) UpstreamReader() io.Reader {
|
||||
func (r *Reader) Upstream() any {
|
||||
return r.upstream
|
||||
}
|
||||
|
||||
func (r *Reader) ReaderReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *Reader) SetReader(reader io.Reader) {
|
||||
r.upstream = reader
|
||||
}
|
||||
|
||||
func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
|
||||
if r.cached > 0 {
|
||||
writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached])
|
||||
|
@ -295,18 +287,10 @@ func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buf
|
|||
}
|
||||
}
|
||||
|
||||
func (w *Writer) UpstreamWriter() io.Writer {
|
||||
func (w *Writer) Upstream() any {
|
||||
return w.upstream
|
||||
}
|
||||
|
||||
func (w *Writer) WriterReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *Writer) SetWriter(writer io.Writer) {
|
||||
w.upstream = writer
|
||||
}
|
||||
|
||||
func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
for {
|
||||
offset := Overhead + PacketLengthBufferSize
|
||||
|
|
|
@ -279,26 +279,8 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *clientConn) UpstreamReader() io.Reader {
|
||||
if c.reader == nil {
|
||||
return c.Conn
|
||||
}
|
||||
return c.reader
|
||||
}
|
||||
|
||||
func (c *clientConn) ReaderReplaceable() bool {
|
||||
return c.reader != nil
|
||||
}
|
||||
|
||||
func (c *clientConn) UpstreamWriter() io.Writer {
|
||||
if c.writer == nil {
|
||||
return c.Conn
|
||||
}
|
||||
return c.writer
|
||||
}
|
||||
|
||||
func (c *clientConn) WriterReplaceable() bool {
|
||||
return c.writer != nil
|
||||
func (c *clientConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
|
@ -375,18 +357,6 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) UpstreamReader() io.Reader {
|
||||
func (c *clientPacketConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReaderReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) UpstreamWriter() io.Writer {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WriterReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -186,26 +186,8 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *serverConn) UpstreamReader() io.Reader {
|
||||
if c.reader == nil {
|
||||
return c.Conn
|
||||
}
|
||||
return c.reader
|
||||
}
|
||||
|
||||
func (c *serverConn) ReaderReplaceable() bool {
|
||||
return c.reader != nil
|
||||
}
|
||||
|
||||
func (c *serverConn) UpstreamWriter() io.Writer {
|
||||
if c.writer == nil {
|
||||
return c.Conn
|
||||
}
|
||||
return c.writer
|
||||
}
|
||||
|
||||
func (c *serverConn) WriterReplaceable() bool {
|
||||
return c.writer != nil
|
||||
func (c *serverConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
|
|
|
@ -394,26 +394,8 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
return c.writer.ReadFrom(r)
|
||||
}
|
||||
|
||||
func (c *clientConn) UpstreamReader() io.Reader {
|
||||
if c.reader == nil {
|
||||
return c.Conn
|
||||
}
|
||||
return c.reader
|
||||
}
|
||||
|
||||
func (c *clientConn) ReaderReplaceable() bool {
|
||||
return c.reader != nil
|
||||
}
|
||||
|
||||
func (c *clientConn) UpstreamWriter() io.Writer {
|
||||
if c.writer == nil {
|
||||
return c.Conn
|
||||
}
|
||||
return c.writer
|
||||
}
|
||||
|
||||
func (c *clientConn) WriterReplaceable() bool {
|
||||
return c.writer != nil
|
||||
func (c *clientConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
|
@ -732,18 +714,6 @@ func (m *Method) newUDPSession() *udpSession {
|
|||
return session
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) UpstreamReader() io.Reader {
|
||||
func (c *clientPacketConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReaderReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) UpstreamWriter() io.Writer {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WriterReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -279,26 +279,8 @@ func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
|
|||
return c.reader.WriteTo(w)
|
||||
}
|
||||
|
||||
func (c *serverConn) UpstreamReader() io.Reader {
|
||||
if c.reader == nil {
|
||||
return c.Conn
|
||||
}
|
||||
return c.reader
|
||||
}
|
||||
|
||||
func (c *serverConn) ReaderReplaceable() bool {
|
||||
return c.reader != nil
|
||||
}
|
||||
|
||||
func (c *serverConn) UpstreamWriter() io.Writer {
|
||||
if c.writer == nil {
|
||||
return c.Conn
|
||||
}
|
||||
return c.writer
|
||||
}
|
||||
|
||||
func (c *serverConn) WriterReplaceable() bool {
|
||||
return c.writer != nil
|
||||
func (c *serverConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (s *Service) NewPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
|
||||
|
|
|
@ -303,22 +303,10 @@ func (c *clientConn) Write(p []byte) (n int, err error) {
|
|||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *clientConn) UpstreamReader() io.Reader {
|
||||
func (c *clientConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientConn) ReaderReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *clientConn) UpstreamWriter() io.Writer {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientConn) WriterReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type clientPacketConn struct {
|
||||
*Method
|
||||
net.Conn
|
||||
|
@ -400,18 +388,6 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) UpstreamReader() io.Reader {
|
||||
func (c *clientPacketConn) Upstream() any {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) ReaderReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) UpstreamWriter() io.Writer {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func (c *clientPacketConn) WriterReplaceable() bool {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -25,23 +25,6 @@ type Listener struct {
|
|||
*net.TCPListener
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Conn net.Conn
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return e.Cause.Error()
|
||||
}
|
||||
|
||||
func (e *Error) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
func (e *Error) Close() error {
|
||||
return common.Close(e.Conn)
|
||||
}
|
||||
|
||||
func NewTCPListener(listen netip.AddrPort, handler Handler, options ...Option) *Listener {
|
||||
listener := &Listener{
|
||||
bind: listen,
|
||||
|
@ -112,9 +95,26 @@ func (l *Listener) loop() {
|
|||
metadata.Protocol = "tcp"
|
||||
hErr := l.handler.NewConnection(context.Background(), tcpConn, metadata)
|
||||
if hErr != nil {
|
||||
l.handler.HandleError(&Error{Conn: tcpConn, Cause: hErr})
|
||||
l.handler.HandleError(hErr)
|
||||
}
|
||||
debug.Free()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Conn net.Conn
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return e.Cause.Error()
|
||||
}
|
||||
|
||||
func (e *Error) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
func (e *Error) Close() error {
|
||||
return common.Close(e.Conn)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue