mirror of
https://github.com/SagerNet/sing-shadowtls.git
synced 2025-04-03 12:17:36 +03:00
Add multiple server names and multi-user
This commit is contained in:
parent
e986e9cd9e
commit
9cb3e9e0ed
2 changed files with 115 additions and 44 deletions
117
service.go
117
service.go
|
@ -9,6 +9,7 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
|
@ -20,12 +21,23 @@ import (
|
|||
)
|
||||
|
||||
type ServiceConfig struct {
|
||||
Version int
|
||||
Password string
|
||||
HandshakeServer M.Socksaddr
|
||||
HandshakeDialer N.Dialer
|
||||
Handler Handler
|
||||
Logger logger.ContextLogger
|
||||
Version int
|
||||
Password string // for protocol version 2
|
||||
Users []User // for protocol version 3
|
||||
Handshake HandshakeConfig
|
||||
HandshakeForServerName map[string]HandshakeConfig // for protocol version 2/3
|
||||
Handler Handler
|
||||
Logger logger.ContextLogger
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Password string
|
||||
}
|
||||
|
||||
type HandshakeConfig struct {
|
||||
Server M.Socksaddr
|
||||
Dialer N.Dialer
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
|
@ -34,28 +46,39 @@ type Handler interface {
|
|||
}
|
||||
|
||||
type Service struct {
|
||||
version int
|
||||
password string
|
||||
handshakeServer M.Socksaddr
|
||||
handshakeDialer N.Dialer
|
||||
handler Handler
|
||||
logger logger.ContextLogger
|
||||
version int
|
||||
password string
|
||||
users []User
|
||||
handshake HandshakeConfig
|
||||
handshakeForServerName map[string]HandshakeConfig
|
||||
handler Handler
|
||||
logger logger.ContextLogger
|
||||
}
|
||||
|
||||
func NewService(config ServiceConfig) (*Service, error) {
|
||||
service := &Service{
|
||||
version: config.Version,
|
||||
password: config.Password,
|
||||
handshakeServer: config.HandshakeServer,
|
||||
handshakeDialer: config.HandshakeDialer,
|
||||
handler: config.Handler,
|
||||
logger: config.Logger,
|
||||
version: config.Version,
|
||||
password: config.Password,
|
||||
users: config.Users,
|
||||
handshake: config.Handshake,
|
||||
handshakeForServerName: config.HandshakeForServerName,
|
||||
handler: config.Handler,
|
||||
logger: config.Logger,
|
||||
}
|
||||
if !service.handshakeServer.IsValid() || service.handler == nil || service.logger == nil {
|
||||
|
||||
if !service.handshake.Server.IsValid() {
|
||||
return nil, E.New("missing default handshake information")
|
||||
}
|
||||
|
||||
if service.handler == nil || service.logger == nil {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
switch config.Version {
|
||||
case 1, 2, 3:
|
||||
case 1, 2:
|
||||
case 3:
|
||||
if len(service.users) == 0 {
|
||||
return nil, E.New("missing users")
|
||||
}
|
||||
default:
|
||||
return nil, E.New("unknown protocol version: ", config.Version)
|
||||
}
|
||||
|
@ -63,15 +86,26 @@ func NewService(config ServiceConfig) (*Service, error) {
|
|||
return service, nil
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
handshakeConn, err := s.handshakeDialer.DialContext(ctx, N.NetworkTCP, s.handshakeServer)
|
||||
if err != nil {
|
||||
return E.Cause(err, "server handshake")
|
||||
func (s *Service) selectHandshake(clientHelloFrame *buf.Buffer) HandshakeConfig {
|
||||
serverName, err := extractServerName(clientHelloFrame.Bytes())
|
||||
if err == nil {
|
||||
if customHandshake, found := s.handshakeForServerName[serverName]; found {
|
||||
return customHandshake
|
||||
}
|
||||
}
|
||||
return s.handshake
|
||||
}
|
||||
|
||||
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
switch s.version {
|
||||
default:
|
||||
fallthrough
|
||||
case 1:
|
||||
handshakeConn, err := s.handshake.Dialer.DialContext(ctx, N.NetworkTCP, s.handshake.Server)
|
||||
if err != nil {
|
||||
return E.Cause(err, "server handshake")
|
||||
}
|
||||
|
||||
var group task.Group
|
||||
group.Append("client handshake", func(ctx context.Context) error {
|
||||
return copyUntilHandshakeFinished(handshakeConn, conn)
|
||||
|
@ -90,10 +124,20 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
s.logger.TraceContext(ctx, "handshake finished")
|
||||
return s.handler.NewConnection(ctx, conn, metadata)
|
||||
case 2:
|
||||
clientHelloFrame, err := extractFrame(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read client handshake")
|
||||
}
|
||||
|
||||
handshakeConfig := s.selectHandshake(clientHelloFrame)
|
||||
handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
|
||||
if err != nil {
|
||||
return E.Cause(err, "server handshake")
|
||||
}
|
||||
hashConn := newHashWriteConn(conn, s.password)
|
||||
go bufio.Copy(hashConn, handshakeConn)
|
||||
var request *buf.Buffer
|
||||
request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, conn, hashConn, 2)
|
||||
request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, bufio.NewCachedConn(conn, clientHelloFrame), hashConn, 2)
|
||||
if err == nil {
|
||||
s.logger.TraceContext(ctx, "handshake finished")
|
||||
handshakeConn.Close()
|
||||
|
@ -106,21 +150,30 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
return err
|
||||
}
|
||||
case 3:
|
||||
var clientHelloFrame *buf.Buffer
|
||||
clientHelloFrame, err = extractFrame(conn)
|
||||
clientHelloFrame, err := extractFrame(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read client handshake")
|
||||
}
|
||||
|
||||
handshakeConfig := s.selectHandshake(clientHelloFrame)
|
||||
handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
|
||||
if err != nil {
|
||||
return E.Cause(err, "server handshake")
|
||||
}
|
||||
|
||||
_, err = handshakeConn.Write(clientHelloFrame.Bytes())
|
||||
if err != nil {
|
||||
clientHelloFrame.Release()
|
||||
return E.Cause(err, "write client handshake")
|
||||
}
|
||||
err = verifyClientHello(clientHelloFrame.Bytes(), s.password)
|
||||
user, err := verifyClientHello(clientHelloFrame.Bytes(), s.users)
|
||||
if err != nil {
|
||||
s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed"))
|
||||
return bufio.CopyConn(ctx, conn, handshakeConn)
|
||||
}
|
||||
if user.Name != "" {
|
||||
ctx = auth.ContextWithUser(ctx, user.Name)
|
||||
}
|
||||
s.logger.TraceContext(ctx, "client hello verify success")
|
||||
clientHelloFrame.Release()
|
||||
|
||||
|
@ -152,12 +205,12 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
if debug.Enabled {
|
||||
s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom))
|
||||
}
|
||||
hmacWrite := hmac.New(sha1.New, []byte(s.password))
|
||||
hmacWrite := hmac.New(sha1.New, []byte(user.Password))
|
||||
hmacWrite.Write(serverRandom)
|
||||
hmacAdd := hmac.New(sha1.New, []byte(s.password))
|
||||
hmacAdd := hmac.New(sha1.New, []byte(user.Password))
|
||||
hmacAdd.Write(serverRandom)
|
||||
hmacAdd.Write([]byte("S"))
|
||||
hmacVerify := hmac.New(sha1.New, []byte(s.password))
|
||||
hmacVerify := hmac.New(sha1.New, []byte(user.Password))
|
||||
hmacVerifyReset := func() {
|
||||
hmacVerify.Reset()
|
||||
hmacVerify.Write(serverRandom)
|
||||
|
@ -177,7 +230,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
return cErr
|
||||
})
|
||||
group.Append("server handshake relay", func(ctx context.Context) error {
|
||||
cErr := copyByFrameWithModification(handshakeConn, conn, s.password, serverRandom, hmacWrite)
|
||||
cErr := copyByFrameWithModification(handshakeConn, conn, user.Password, serverRandom, hmacWrite)
|
||||
if E.IsClosedOrCanceled(cErr) && handshakeFinished {
|
||||
return nil
|
||||
}
|
||||
|
|
42
v3_server.go
42
v3_server.go
|
@ -2,8 +2,10 @@ package shadowtls
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"hash"
|
||||
"io"
|
||||
|
@ -32,26 +34,42 @@ func extractFrame(conn net.Conn) (*buf.Buffer, error) {
|
|||
return buffer, err
|
||||
}
|
||||
|
||||
func verifyClientHello(frame []byte, password string) error {
|
||||
func extractServerName(frame []byte) (string, error) {
|
||||
var hello *tls.ClientHelloInfo
|
||||
err := tls.Server(bufio.NewReadOnlyConn(bytes.NewReader(frame)), &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = argHello
|
||||
return nil, nil
|
||||
},
|
||||
}).HandshakeContext(context.Background())
|
||||
if hello != nil {
|
||||
return hello.ServerName, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
func verifyClientHello(frame []byte, users []User) (*User, error) {
|
||||
const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + 1 + tlsSessionIDSize
|
||||
const hmacIndex = sessionIDLengthIndex + 1 + tlsSessionIDSize - hmacSize
|
||||
if len(frame) < minLen {
|
||||
return io.ErrUnexpectedEOF
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
} else if frame[0] != handshake {
|
||||
return E.New("unexpected record type")
|
||||
return nil, E.New("unexpected record type")
|
||||
} else if frame[5] != clientHello {
|
||||
return E.New("unexpected handshake type")
|
||||
return nil, E.New("unexpected handshake type")
|
||||
} else if frame[sessionIDLengthIndex] != tlsSessionIDSize {
|
||||
return E.New("unexpected session id length")
|
||||
return nil, E.New("unexpected session id length")
|
||||
}
|
||||
hmacSHA1Hash := hmac.New(sha1.New, []byte(password))
|
||||
hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex])
|
||||
hmacSHA1Hash.Write(rw.ZeroBytes[:4])
|
||||
hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:])
|
||||
if !hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) {
|
||||
return E.New("hmac mismatch")
|
||||
for _, user := range users {
|
||||
hmacSHA1Hash := hmac.New(sha1.New, []byte(user.Password))
|
||||
hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex])
|
||||
hmacSHA1Hash.Write(rw.ZeroBytes[:4])
|
||||
hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:])
|
||||
if hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) {
|
||||
return &user, nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, E.New("hmac mismatch")
|
||||
}
|
||||
|
||||
func extractServerRandom(frame []byte) []byte {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue