Add multiple server names and multi-user

This commit is contained in:
世界 2023-02-21 15:45:14 +08:00
parent e986e9cd9e
commit 9cb3e9e0ed
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 115 additions and 44 deletions

View file

@ -9,6 +9,7 @@ import (
"os" "os"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/debug" "github.com/sagernet/sing/common/debug"
@ -21,13 +22,24 @@ import (
type ServiceConfig struct { type ServiceConfig struct {
Version int Version int
Password string Password string // for protocol version 2
HandshakeServer M.Socksaddr Users []User // for protocol version 3
HandshakeDialer N.Dialer Handshake HandshakeConfig
HandshakeForServerName map[string]HandshakeConfig // for protocol version 2/3
Handler Handler Handler Handler
Logger logger.ContextLogger Logger logger.ContextLogger
} }
type User struct {
Name string
Password string
}
type HandshakeConfig struct {
Server M.Socksaddr
Dialer N.Dialer
}
type Handler interface { type Handler interface {
N.TCPConnectionHandler N.TCPConnectionHandler
E.Handler E.Handler
@ -36,8 +48,9 @@ type Handler interface {
type Service struct { type Service struct {
version int version int
password string password string
handshakeServer M.Socksaddr users []User
handshakeDialer N.Dialer handshake HandshakeConfig
handshakeForServerName map[string]HandshakeConfig
handler Handler handler Handler
logger logger.ContextLogger logger logger.ContextLogger
} }
@ -46,16 +59,26 @@ func NewService(config ServiceConfig) (*Service, error) {
service := &Service{ service := &Service{
version: config.Version, version: config.Version,
password: config.Password, password: config.Password,
handshakeServer: config.HandshakeServer, users: config.Users,
handshakeDialer: config.HandshakeDialer, handshake: config.Handshake,
handshakeForServerName: config.HandshakeForServerName,
handler: config.Handler, handler: config.Handler,
logger: config.Logger, logger: config.Logger,
} }
if !service.handshakeServer.IsValid() || service.handler == nil || service.logger == nil {
if !service.handshake.Server.IsValid() {
return nil, E.New("missing default handshake information")
}
if service.handler == nil || service.logger == nil {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
switch config.Version { switch config.Version {
case 1, 2, 3: case 1, 2:
case 3:
if len(service.users) == 0 {
return nil, E.New("missing users")
}
default: default:
return nil, E.New("unknown protocol version: ", config.Version) return nil, E.New("unknown protocol version: ", config.Version)
} }
@ -63,15 +86,26 @@ func NewService(config ServiceConfig) (*Service, error) {
return service, nil return service, nil
} }
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (s *Service) selectHandshake(clientHelloFrame *buf.Buffer) HandshakeConfig {
handshakeConn, err := s.handshakeDialer.DialContext(ctx, N.NetworkTCP, s.handshakeServer) serverName, err := extractServerName(clientHelloFrame.Bytes())
if err != nil { if err == nil {
return E.Cause(err, "server handshake") if customHandshake, found := s.handshakeForServerName[serverName]; found {
return customHandshake
} }
}
return s.handshake
}
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
switch s.version { switch s.version {
default: default:
fallthrough fallthrough
case 1: case 1:
handshakeConn, err := s.handshake.Dialer.DialContext(ctx, N.NetworkTCP, s.handshake.Server)
if err != nil {
return E.Cause(err, "server handshake")
}
var group task.Group var group task.Group
group.Append("client handshake", func(ctx context.Context) error { group.Append("client handshake", func(ctx context.Context) error {
return copyUntilHandshakeFinished(handshakeConn, conn) return copyUntilHandshakeFinished(handshakeConn, conn)
@ -90,10 +124,20 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
s.logger.TraceContext(ctx, "handshake finished") s.logger.TraceContext(ctx, "handshake finished")
return s.handler.NewConnection(ctx, conn, metadata) return s.handler.NewConnection(ctx, conn, metadata)
case 2: case 2:
clientHelloFrame, err := extractFrame(conn)
if err != nil {
return E.Cause(err, "read client handshake")
}
handshakeConfig := s.selectHandshake(clientHelloFrame)
handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
if err != nil {
return E.Cause(err, "server handshake")
}
hashConn := newHashWriteConn(conn, s.password) hashConn := newHashWriteConn(conn, s.password)
go bufio.Copy(hashConn, handshakeConn) go bufio.Copy(hashConn, handshakeConn)
var request *buf.Buffer var request *buf.Buffer
request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, conn, hashConn, 2) request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, bufio.NewCachedConn(conn, clientHelloFrame), hashConn, 2)
if err == nil { if err == nil {
s.logger.TraceContext(ctx, "handshake finished") s.logger.TraceContext(ctx, "handshake finished")
handshakeConn.Close() handshakeConn.Close()
@ -106,21 +150,30 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
return err return err
} }
case 3: case 3:
var clientHelloFrame *buf.Buffer clientHelloFrame, err := extractFrame(conn)
clientHelloFrame, err = extractFrame(conn)
if err != nil { if err != nil {
return E.Cause(err, "read client handshake") return E.Cause(err, "read client handshake")
} }
handshakeConfig := s.selectHandshake(clientHelloFrame)
handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
if err != nil {
return E.Cause(err, "server handshake")
}
_, err = handshakeConn.Write(clientHelloFrame.Bytes()) _, err = handshakeConn.Write(clientHelloFrame.Bytes())
if err != nil { if err != nil {
clientHelloFrame.Release() clientHelloFrame.Release()
return E.Cause(err, "write client handshake") return E.Cause(err, "write client handshake")
} }
err = verifyClientHello(clientHelloFrame.Bytes(), s.password) user, err := verifyClientHello(clientHelloFrame.Bytes(), s.users)
if err != nil { if err != nil {
s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed")) s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed"))
return bufio.CopyConn(ctx, conn, handshakeConn) return bufio.CopyConn(ctx, conn, handshakeConn)
} }
if user.Name != "" {
ctx = auth.ContextWithUser(ctx, user.Name)
}
s.logger.TraceContext(ctx, "client hello verify success") s.logger.TraceContext(ctx, "client hello verify success")
clientHelloFrame.Release() clientHelloFrame.Release()
@ -152,12 +205,12 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
if debug.Enabled { if debug.Enabled {
s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom)) s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom))
} }
hmacWrite := hmac.New(sha1.New, []byte(s.password)) hmacWrite := hmac.New(sha1.New, []byte(user.Password))
hmacWrite.Write(serverRandom) hmacWrite.Write(serverRandom)
hmacAdd := hmac.New(sha1.New, []byte(s.password)) hmacAdd := hmac.New(sha1.New, []byte(user.Password))
hmacAdd.Write(serverRandom) hmacAdd.Write(serverRandom)
hmacAdd.Write([]byte("S")) hmacAdd.Write([]byte("S"))
hmacVerify := hmac.New(sha1.New, []byte(s.password)) hmacVerify := hmac.New(sha1.New, []byte(user.Password))
hmacVerifyReset := func() { hmacVerifyReset := func() {
hmacVerify.Reset() hmacVerify.Reset()
hmacVerify.Write(serverRandom) hmacVerify.Write(serverRandom)
@ -177,7 +230,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
return cErr return cErr
}) })
group.Append("server handshake relay", func(ctx context.Context) error { group.Append("server handshake relay", func(ctx context.Context) error {
cErr := copyByFrameWithModification(handshakeConn, conn, s.password, serverRandom, hmacWrite) cErr := copyByFrameWithModification(handshakeConn, conn, user.Password, serverRandom, hmacWrite)
if E.IsClosedOrCanceled(cErr) && handshakeFinished { if E.IsClosedOrCanceled(cErr) && handshakeFinished {
return nil return nil
} }

View file

@ -2,8 +2,10 @@ package shadowtls
import ( import (
"bytes" "bytes"
"context"
"crypto/hmac" "crypto/hmac"
"crypto/sha1" "crypto/sha1"
"crypto/tls"
"encoding/binary" "encoding/binary"
"hash" "hash"
"io" "io"
@ -32,26 +34,42 @@ func extractFrame(conn net.Conn) (*buf.Buffer, error) {
return buffer, err return buffer, err
} }
func verifyClientHello(frame []byte, password string) error { func extractServerName(frame []byte) (string, error) {
var hello *tls.ClientHelloInfo
err := tls.Server(bufio.NewReadOnlyConn(bytes.NewReader(frame)), &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = argHello
return nil, nil
},
}).HandshakeContext(context.Background())
if hello != nil {
return hello.ServerName, nil
}
return "", err
}
func verifyClientHello(frame []byte, users []User) (*User, error) {
const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + 1 + tlsSessionIDSize const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + 1 + tlsSessionIDSize
const hmacIndex = sessionIDLengthIndex + 1 + tlsSessionIDSize - hmacSize const hmacIndex = sessionIDLengthIndex + 1 + tlsSessionIDSize - hmacSize
if len(frame) < minLen { if len(frame) < minLen {
return io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} else if frame[0] != handshake { } else if frame[0] != handshake {
return E.New("unexpected record type") return nil, E.New("unexpected record type")
} else if frame[5] != clientHello { } else if frame[5] != clientHello {
return E.New("unexpected handshake type") return nil, E.New("unexpected handshake type")
} else if frame[sessionIDLengthIndex] != tlsSessionIDSize { } else if frame[sessionIDLengthIndex] != tlsSessionIDSize {
return E.New("unexpected session id length") return nil, E.New("unexpected session id length")
} }
hmacSHA1Hash := hmac.New(sha1.New, []byte(password)) for _, user := range users {
hmacSHA1Hash := hmac.New(sha1.New, []byte(user.Password))
hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex]) hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex])
hmacSHA1Hash.Write(rw.ZeroBytes[:4]) hmacSHA1Hash.Write(rw.ZeroBytes[:4])
hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:]) hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:])
if !hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) { if hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) {
return E.New("hmac mismatch") return &user, nil
} }
return nil }
return nil, E.New("hmac mismatch")
} }
func extractServerRandom(frame []byte) []byte { func extractServerRandom(frame []byte) []byte {