More contetxt in connection

This commit is contained in:
世界 2022-04-28 08:28:38 +08:00
parent f16dd7a336
commit 5be6eb2d64
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 33 additions and 27 deletions

View file

@ -302,7 +302,7 @@ func bypass(conn net.Conn, destination *M.AddrPort) error {
}) })
} }
func (c *client) NewConnection(conn net.Conn, metadata M.Metadata) error { func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if c.bypass != "" { if c.bypass != "" {
if metadata.Destination.Addr.Family().IsFqdn() { if metadata.Destination.Addr.Family().IsFqdn() {
if c.Match(metadata.Destination.Addr.Fqdn()) { if c.Match(metadata.Destination.Addr.Fqdn()) {
@ -316,7 +316,6 @@ func (c *client) NewConnection(conn net.Conn, metadata M.Metadata) error {
} }
logrus.Info("outbound ", metadata.Protocol, " TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination) logrus.Info("outbound ", metadata.Protocol, " TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
ctx := context.Background()
serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String()) serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
if err != nil { if err != nil {

View file

@ -131,16 +131,16 @@ func newServer(f *flags) (*server, error) {
return s, nil return s, nil
} }
func (s *server) NewConnection(conn net.Conn, metadata M.Metadata) error { func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if metadata.Protocol != "shadowsocks" { if metadata.Protocol != "shadowsocks" {
return s.service.NewConnection(conn, metadata) return s.service.NewConnection(ctx, conn, metadata)
} }
logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination) logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination) destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
if err != nil { if err != nil {
return err return err
} }
return rw.CopyConn(context.Background(), conn, destConn) return rw.CopyConn(ctx, conn, destConn)
} }
func (s *server) HandleError(err error) { func (s *server) HandleError(err error) {

View file

@ -94,7 +94,7 @@ type localClient struct {
upstream string upstream string
} }
func (c *localClient) NewConnection(conn net.Conn, metadata M.Metadata) error { func (c *localClient) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination) logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
upstream, err := net.Dial("tcp", c.upstream) upstream, err := net.Dial("tcp", c.upstream)

View file

@ -1,6 +1,7 @@
package metadata package metadata
import ( import (
"context"
"net" "net"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -13,7 +14,7 @@ type Metadata struct {
} }
type TCPConnectionHandler interface { type TCPConnectionHandler interface {
NewConnection(conn net.Conn, metadata Metadata) error NewConnection(ctx context.Context, conn net.Conn, metadata Metadata) error
} }
type UDPHandler interface { type UDPHandler interface {

View file

@ -21,7 +21,7 @@ type Handler interface {
tcp.Handler tcp.Handler
} }
func HandleRequest(request *http.Request, conn net.Conn, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error { func HandleRequest(ctx context.Context, request *http.Request, conn net.Conn, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error {
var httpClient *http.Client var httpClient *http.Client
for { for {
if authenticator != nil { if authenticator != nil {
@ -56,7 +56,7 @@ func HandleRequest(request *http.Request, conn net.Conn, authenticator auth.Auth
return E.Cause(err, "write http response") return E.Cause(err, "write http response")
} }
metadata.Destination = destination metadata.Destination = destination
return handler.NewConnection(conn, metadata) return handler.NewConnection(ctx, conn, metadata)
} }
keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" keepAlive := strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
@ -96,7 +96,7 @@ func HandleRequest(request *http.Request, conn net.Conn, authenticator auth.Auth
go func() { go func() {
metadata.Destination = destination metadata.Destination = destination
metadata.Protocol = "http" metadata.Protocol = "http"
err = handler.NewConnection(right, metadata) err = handler.NewConnection(ctx, right, metadata)
if err != nil { if err != nil {
handler.HandleError(&tcp.Error{Conn: right, Cause: err}) handler.HandleError(&tcp.Error{Conn: right, Cause: err})
} }

View file

@ -1,6 +1,7 @@
package shadowsocks package shadowsocks
import ( import (
"context"
"net" "net"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -31,12 +32,12 @@ func NewNoneService(handler Handler) Service {
} }
} }
func (s *NoneService) NewConnection(conn net.Conn, metadata M.Metadata) error { func (s *NoneService) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
destination, err := socks.AddressSerializer.ReadAddrPort(conn) destination, err := socks.AddressSerializer.ReadAddrPort(conn)
if err != nil { if err != nil {
return err return err
} }
metadata.Protocol = "shadowsocks" metadata.Protocol = "shadowsocks"
metadata.Destination = destination metadata.Destination = destination
return s.handler.NewConnection(conn, metadata) return s.handler.NewConnection(ctx, conn, metadata)
} }

View file

@ -1,6 +1,7 @@
package shadowaead package shadowaead
import ( import (
"context"
"crypto/cipher" "crypto/cipher"
"io" "io"
"net" "net"
@ -73,7 +74,7 @@ func NewService(method string, key []byte, password []byte, secureRNG io.Reader,
return s, nil return s, nil
} }
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error { func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
_salt := buf.Make(s.keySaltLength) _salt := buf.Make(s.keySaltLength)
salt := common.Dup(_salt) salt := common.Dup(_salt)
@ -92,7 +93,7 @@ func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
metadata.Protocol = "shadowsocks" metadata.Protocol = "shadowsocks"
metadata.Destination = destination metadata.Destination = destination
return s.handler.NewConnection(&serverConn{ return s.handler.NewConnection(ctx, &serverConn{
Service: s, Service: s,
Conn: conn, Conn: conn,
reader: reader, reader: reader,
@ -153,7 +154,7 @@ func (c *serverConn) Write(p []byte) (n int, err error) {
} }
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer != nil { if c.writer == nil {
return rw.ReadFrom0(c, r) return rw.ReadFrom0(c, r)
} }
return c.writer.ReadFrom(r) return c.writer.ReadFrom(r)

View file

@ -1,6 +1,7 @@
package shadowaead_2022 package shadowaead_2022
import ( import (
"context"
"crypto/cipher" "crypto/cipher"
"encoding/binary" "encoding/binary"
"io" "io"
@ -61,7 +62,7 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso
return s, nil return s, nil
} }
func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error { func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
requestSalt := make([]byte, KeySaltSize) requestSalt := make([]byte, KeySaltSize)
_, err := io.ReadFull(conn, requestSalt) _, err := io.ReadFull(conn, requestSalt)
if err != nil { if err != nil {
@ -117,7 +118,7 @@ func (s *Service) NewConnection(conn net.Conn, metadata M.Metadata) error {
metadata.Protocol = "shadowsocks" metadata.Protocol = "shadowsocks"
metadata.Destination = destination metadata.Destination = destination
return s.handler.NewConnection(&serverConn{ return s.handler.NewConnection(ctx, &serverConn{
Service: s, Service: s,
Conn: conn, Conn: conn,
reader: reader, reader: reader,
@ -184,7 +185,7 @@ func (c *serverConn) Write(p []byte) (n int, err error) {
} }
func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) { func (c *serverConn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writer != nil { if c.writer == nil {
return rw.ReadFrom0(c, r) return rw.ReadFrom0(c, r)
} }
return c.writer.ReadFrom(r) return c.writer.ReadFrom(r)

View file

@ -1,6 +1,7 @@
package socks package socks
import ( import (
"context"
"io" "io"
"net" "net"
"net/netip" "net/netip"
@ -34,8 +35,8 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, handler
return listener return listener
} }
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error { func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
return HandleConnection(conn, l.authenticator, l.bindAddr, l.handler, metadata) return HandleConnection(ctx, conn, l.authenticator, l.bindAddr, l.handler, metadata)
} }
func (l *Listener) Start() error { func (l *Listener) Start() error {
@ -50,7 +51,7 @@ func (l *Listener) HandleError(err error) {
l.handler.HandleError(err) l.handler.HandleError(err)
} }
func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind netip.Addr, handler Handler, metadata M.Metadata) error { func HandleConnection(ctx context.Context, conn net.Conn, authenticator auth.Authenticator, bind netip.Addr, handler Handler, metadata M.Metadata) error {
authRequest, err := ReadAuthRequest(conn) authRequest, err := ReadAuthRequest(conn)
if err != nil { if err != nil {
return E.Cause(err, "read socks auth request") return E.Cause(err, "read socks auth request")
@ -111,7 +112,7 @@ func HandleConnection(conn net.Conn, authenticator auth.Authenticator, bind neti
} }
metadata.Protocol = "socks" metadata.Protocol = "socks"
metadata.Destination = request.Destination metadata.Destination = request.Destination
return handler.NewConnection(conn, metadata) return handler.NewConnection(ctx, conn, metadata)
case CommandUDPAssociate: case CommandUDPAssociate:
network := "udp" network := "udp"
if bind.Is4() { if bind.Is4() {

View file

@ -1,6 +1,7 @@
package mixed package mixed
import ( import (
"context"
"io" "io"
"net" "net"
netHttp "net/http" netHttp "net/http"
@ -49,9 +50,9 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro
return listener return listener
} }
func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error { func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if metadata.Destination != nil { if metadata.Destination != nil {
return l.handler.NewConnection(conn, metadata) return l.handler.NewConnection(ctx, conn, metadata)
} }
bufConn := buf.NewBufferedConn(conn) bufConn := buf.NewBufferedConn(conn)
header, err := bufConn.Peek(1) header, err := bufConn.Peek(1)
@ -62,7 +63,7 @@ func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
case socks.Version4: case socks.Version4:
return E.New("socks4 request dropped (TODO)") return E.New("socks4 request dropped (TODO)")
case socks.Version5: case socks.Version5:
return socks.HandleConnection(bufConn, l.authenticator, l.bindAddr, l.handler, metadata) return socks.HandleConnection(ctx, bufConn, l.authenticator, l.bindAddr, l.handler, metadata)
} }
request, err := http.ReadRequest(bufConn.Reader()) request, err := http.ReadRequest(bufConn.Reader())
@ -92,7 +93,7 @@ func (l *Listener) NewConnection(conn net.Conn, metadata M.Metadata) error {
return nil return nil
} }
return http.HandleRequest(request, bufConn, l.authenticator, l.handler, metadata) return http.HandleRequest(ctx, request, bufConn, l.authenticator, l.handler, metadata)
} }
func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error { func (l *Listener) NewPacket(packet *buf.Buffer, metadata M.Metadata) error {

View file

@ -1,6 +1,7 @@
package tcp package tcp
import ( import (
"context"
"net" "net"
"net/netip" "net/netip"
@ -108,7 +109,7 @@ func (l *Listener) loop() {
} }
go func() { go func() {
metadata.Protocol = "tcp" metadata.Protocol = "tcp"
hErr := l.handler.NewConnection(tcpConn, metadata) hErr := l.handler.NewConnection(context.Background(), tcpConn, metadata)
if hErr != nil { if hErr != nil {
l.handler.HandleError(&Error{Conn: tcpConn, Cause: hErr}) l.handler.HandleError(&Error{Conn: tcpConn, Cause: hErr})
} }