diff --git a/go.mod b/go.mod index 0fe2936..73d6386 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowtls go 1.20 require ( - github.com/sagernet/sing v0.6.0 + github.com/sagernet/sing v0.6.3 golang.org/x/crypto v0.32.0 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 ) diff --git a/go.sum b/go.sum index 348862d..71e015c 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/sagernet/sing v0.6.0 h1:jT55zAXrG7H3x+s/FlrC15xQy3LcmuZ2GGA9+8IJdt0= -github.com/sagernet/sing v0.6.0/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.3 h1:J1spMc6LMlqUvRjWjvNMAcbvACDneqxB9zxfLuS0UTE= +github.com/sagernet/sing v0.6.3/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= diff --git a/service.go b/service.go index 15a1aaa..84f12f9 100644 --- a/service.go +++ b/service.go @@ -27,10 +27,19 @@ type Service struct { handshake HandshakeConfig handshakeForServerName map[string]HandshakeConfig strictMode bool + wildcardSNI WildcardSNI handler N.TCPConnectionHandlerEx logger logger.ContextLogger } +type WildcardSNI int + +const ( + WildcardSNIOff WildcardSNI = iota + WildcardSNIAuthed + WildcardSNIAll +) + type ServiceConfig struct { Version int Password string // for protocol version 2 @@ -38,6 +47,7 @@ type ServiceConfig struct { Handshake HandshakeConfig HandshakeForServerName map[string]HandshakeConfig // for protocol version 2/3 StrictMode bool // for protocol version 3 + WildcardSNI WildcardSNI // for protocol version 3 Handler N.TCPConnectionHandlerEx Logger logger.ContextLogger } @@ -60,11 +70,12 @@ func NewService(config ServiceConfig) (*Service, error) { handshake: config.Handshake, handshakeForServerName: config.HandshakeForServerName, strictMode: config.StrictMode, + wildcardSNI: config.WildcardSNI, handler: config.Handler, logger: config.Logger, } - if !service.handshake.Server.IsValid() { + if !service.handshake.Server.IsValid() && service.wildcardSNI == WildcardSNIOff { return nil, E.New("missing default handshake information") } @@ -84,16 +95,6 @@ func NewService(config ServiceConfig) (*Service, error) { return service, nil } -func (s *Service) selectHandshake(clientHelloFrame *buf.Buffer) HandshakeConfig { - serverName, err := extractServerName(clientHelloFrame.Bytes()) - if err == nil { - if customHandshake, found := s.handshakeForServerName[serverName]; found { - return customHandshake - } - } - return s.handshake -} - func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) error { switch s.version { default: @@ -127,8 +128,17 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Soc if err != nil { return E.Cause(err, "read client handshake") } - - handshakeConfig := s.selectHandshake(clientHelloFrame) + serverName, err := extractServerName(clientHelloFrame.Bytes()) + var handshakeConfig HandshakeConfig + if err == nil { + if customHandshake, found := s.handshakeForServerName[serverName]; found { + handshakeConfig = customHandshake + } else { + handshakeConfig = s.handshake + } + } else { + handshakeConfig = s.handshake + } handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server) if err != nil { return E.Cause(err, "server handshake") @@ -154,28 +164,56 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Soc if err != nil { return E.Cause(err, "read client handshake") } - - handshakeConfig := s.selectHandshake(clientHelloFrame) - handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server) + defer clientHelloFrame.Release() + serverName, err := extractServerName(clientHelloFrame.Bytes()) if err != nil { - return E.Cause(err, "server handshake") + return E.Cause(err, "extract server name") } - - _, err = handshakeConn.Write(clientHelloFrame.Bytes()) - if err != nil { - clientHelloFrame.Release() - return E.Cause(err, "write client handshake") + var ( + handshakeConfig HandshakeConfig + isCustom bool + ) + if customHandshake, found := s.handshakeForServerName[serverName]; found { + handshakeConfig = customHandshake + isCustom = true + } else { + handshakeConfig = s.handshake + if s.wildcardSNI != WildcardSNIOff { + handshakeConfig.Server = M.Socksaddr{ + Fqdn: serverName, + Port: 443, + } + } } + var handshakeConn net.Conn user, err := verifyClientHello(clientHelloFrame.Bytes(), s.users) if err != nil { s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed")) - return bufio.CopyConn(ctx, conn, handshakeConn) + if s.wildcardSNI == WildcardSNIAll || isCustom { + handshakeConn, err = handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server) + } else { + handshakeConn, err = s.handshake.Dialer.DialContext(ctx, N.NetworkTCP, s.handshake.Server) + } + if err != nil { + return E.Cause(err, "server handshake") + } + return bufio.CopyConn(ctx, bufio.NewCachedConn(conn, clientHelloFrame), handshakeConn) } if user.Name != "" { ctx = auth.ContextWithUser(ctx, user.Name) } s.logger.TraceContext(ctx, "client hello verify success") + + handshakeConn, err = handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server) + if err != nil { + return E.Cause(err, "server handshake") + } + + _, err = handshakeConn.Write(clientHelloFrame.Bytes()) clientHelloFrame.Release() + if err != nil { + return E.Cause(err, "write client handshake") + } var serverHelloFrame *buf.Buffer serverHelloFrame, err = extractFrame(handshakeConn)