mirror of
https://github.com/SagerNet/sing-shadowtls.git
synced 2025-03-31 10:47:35 +03:00
250 lines
7.6 KiB
Go
250 lines
7.6 KiB
Go
package shadowtls
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha1"
|
|
"encoding/hex"
|
|
"net"
|
|
"os"
|
|
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/auth"
|
|
"github.com/sagernet/sing/common/buf"
|
|
"github.com/sagernet/sing/common/bufio"
|
|
"github.com/sagernet/sing/common/debug"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
"github.com/sagernet/sing/common/logger"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
"github.com/sagernet/sing/common/task"
|
|
)
|
|
|
|
type Service struct {
|
|
version int
|
|
password string
|
|
users []User
|
|
handshake HandshakeConfig
|
|
handshakeForServerName map[string]HandshakeConfig
|
|
strictMode bool
|
|
handler N.TCPConnectionHandlerEx
|
|
logger logger.ContextLogger
|
|
}
|
|
|
|
type ServiceConfig struct {
|
|
Version int
|
|
Password string // for protocol version 2
|
|
Users []User // for protocol version 3
|
|
Handshake HandshakeConfig
|
|
HandshakeForServerName map[string]HandshakeConfig // for protocol version 2/3
|
|
StrictMode bool // for protocol version 3
|
|
Handler N.TCPConnectionHandlerEx
|
|
Logger logger.ContextLogger
|
|
}
|
|
|
|
// User is a single credential entry for protocol version 3.
type User struct {
	// Name identifies the user; when non-empty it is attached to the
	// connection context after a successful client hello verification.
	Name string
	// Password is the key used for the per-connection HMAC computations.
	Password string
}
|
|
|
|
type HandshakeConfig struct {
|
|
Server M.Socksaddr
|
|
Dialer N.Dialer
|
|
}
|
|
|
|
func NewService(config ServiceConfig) (*Service, error) {
|
|
service := &Service{
|
|
version: config.Version,
|
|
password: config.Password,
|
|
users: config.Users,
|
|
handshake: config.Handshake,
|
|
handshakeForServerName: config.HandshakeForServerName,
|
|
strictMode: config.StrictMode,
|
|
handler: config.Handler,
|
|
logger: config.Logger,
|
|
}
|
|
|
|
if !service.handshake.Server.IsValid() {
|
|
return nil, E.New("missing default handshake information")
|
|
}
|
|
|
|
if service.handler == nil || service.logger == nil {
|
|
return nil, os.ErrInvalid
|
|
}
|
|
switch config.Version {
|
|
case 1, 2:
|
|
case 3:
|
|
if len(service.users) == 0 {
|
|
return nil, E.New("missing users")
|
|
}
|
|
default:
|
|
return nil, E.New("unknown protocol version: ", config.Version)
|
|
}
|
|
|
|
return service, nil
|
|
}
|
|
|
|
func (s *Service) selectHandshake(clientHelloFrame *buf.Buffer) HandshakeConfig {
|
|
serverName, err := extractServerName(clientHelloFrame.Bytes())
|
|
if err == nil {
|
|
if customHandshake, found := s.handshakeForServerName[serverName]; found {
|
|
return customHandshake
|
|
}
|
|
}
|
|
return s.handshake
|
|
}
|
|
|
|
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) error {
|
|
switch s.version {
|
|
default:
|
|
fallthrough
|
|
case 1:
|
|
handshakeConn, err := s.handshake.Dialer.DialContext(ctx, N.NetworkTCP, s.handshake.Server)
|
|
if err != nil {
|
|
return E.Cause(err, "server handshake")
|
|
}
|
|
|
|
var group task.Group
|
|
group.Append("client handshake", func(ctx context.Context) error {
|
|
return copyUntilHandshakeFinished(handshakeConn, conn)
|
|
})
|
|
group.Append("server handshake", func(ctx context.Context) error {
|
|
return copyUntilHandshakeFinished(conn, handshakeConn)
|
|
})
|
|
group.FastFail()
|
|
group.Cleanup(func() {
|
|
handshakeConn.Close()
|
|
})
|
|
err = group.Run(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.logger.TraceContext(ctx, "handshake finished")
|
|
s.handler.NewConnectionEx(ctx, conn, source, destination, onClose)
|
|
return nil
|
|
case 2:
|
|
clientHelloFrame, err := extractFrame(conn)
|
|
if err != nil {
|
|
return E.Cause(err, "read client handshake")
|
|
}
|
|
|
|
handshakeConfig := s.selectHandshake(clientHelloFrame)
|
|
handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
|
|
if err != nil {
|
|
return E.Cause(err, "server handshake")
|
|
}
|
|
hashConn := newHashWriteConn(conn, s.password)
|
|
go bufio.Copy(hashConn, handshakeConn)
|
|
var request *buf.Buffer
|
|
request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, bufio.NewCachedConn(conn, clientHelloFrame), hashConn, 2)
|
|
if err == nil {
|
|
s.logger.TraceContext(ctx, "handshake finished")
|
|
handshakeConn.Close()
|
|
s.handler.NewConnectionEx(ctx, bufio.NewCachedConn(newConn(conn), request), source, destination, onClose)
|
|
return nil
|
|
} else if err == os.ErrPermission {
|
|
s.logger.WarnContext(ctx, "fallback connection")
|
|
hashConn.Fallback()
|
|
return common.Error(bufio.Copy(handshakeConn, conn))
|
|
} else {
|
|
return err
|
|
}
|
|
case 3:
|
|
clientHelloFrame, err := extractFrame(conn)
|
|
if err != nil {
|
|
return E.Cause(err, "read client handshake")
|
|
}
|
|
|
|
handshakeConfig := s.selectHandshake(clientHelloFrame)
|
|
handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
|
|
if err != nil {
|
|
return E.Cause(err, "server handshake")
|
|
}
|
|
|
|
_, err = handshakeConn.Write(clientHelloFrame.Bytes())
|
|
if err != nil {
|
|
clientHelloFrame.Release()
|
|
return E.Cause(err, "write client handshake")
|
|
}
|
|
user, err := verifyClientHello(clientHelloFrame.Bytes(), s.users)
|
|
if err != nil {
|
|
s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed"))
|
|
return bufio.CopyConn(ctx, conn, handshakeConn)
|
|
}
|
|
if user.Name != "" {
|
|
ctx = auth.ContextWithUser(ctx, user.Name)
|
|
}
|
|
s.logger.TraceContext(ctx, "client hello verify success")
|
|
clientHelloFrame.Release()
|
|
|
|
var serverHelloFrame *buf.Buffer
|
|
serverHelloFrame, err = extractFrame(handshakeConn)
|
|
if err != nil {
|
|
return E.Cause(err, "read server handshake")
|
|
}
|
|
|
|
_, err = conn.Write(serverHelloFrame.Bytes())
|
|
if err != nil {
|
|
serverHelloFrame.Release()
|
|
return E.Cause(err, "write server handshake")
|
|
}
|
|
|
|
serverRandom := extractServerRandom(serverHelloFrame.Bytes())
|
|
|
|
if serverRandom == nil {
|
|
s.logger.WarnContext(ctx, "server random extract failed, will copy bidirectional")
|
|
return bufio.CopyConn(ctx, conn, handshakeConn)
|
|
}
|
|
|
|
if s.strictMode && !isServerHelloSupportTLS13(serverHelloFrame.Bytes()) {
|
|
s.logger.WarnContext(ctx, "TLS 1.3 is not supported, will copy bidirectional")
|
|
return bufio.CopyConn(ctx, conn, handshakeConn)
|
|
}
|
|
|
|
serverHelloFrame.Release()
|
|
if debug.Enabled {
|
|
s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom))
|
|
}
|
|
hmacWrite := hmac.New(sha1.New, []byte(user.Password))
|
|
hmacWrite.Write(serverRandom)
|
|
hmacAdd := hmac.New(sha1.New, []byte(user.Password))
|
|
hmacAdd.Write(serverRandom)
|
|
hmacAdd.Write([]byte("S"))
|
|
hmacVerify := hmac.New(sha1.New, []byte(user.Password))
|
|
hmacVerifyReset := func() {
|
|
hmacVerify.Reset()
|
|
hmacVerify.Write(serverRandom)
|
|
hmacVerify.Write([]byte("C"))
|
|
}
|
|
|
|
var clientFirstFrame *buf.Buffer
|
|
var group task.Group
|
|
var handshakeFinished bool
|
|
group.Append("client handshake relay", func(ctx context.Context) error {
|
|
clientFrame, cErr := copyByFrameUntilHMACMatches(conn, handshakeConn, hmacVerify, hmacVerifyReset)
|
|
if cErr == nil {
|
|
clientFirstFrame = clientFrame
|
|
handshakeFinished = true
|
|
handshakeConn.Close()
|
|
}
|
|
return cErr
|
|
})
|
|
group.Append("server handshake relay", func(ctx context.Context) error {
|
|
cErr := copyByFrameWithModification(handshakeConn, conn, user.Password, serverRandom, hmacWrite)
|
|
if E.IsClosedOrCanceled(cErr) && handshakeFinished {
|
|
return nil
|
|
}
|
|
return cErr
|
|
})
|
|
group.Cleanup(func() {
|
|
handshakeConn.Close()
|
|
})
|
|
err = group.Run(ctx)
|
|
if err != nil {
|
|
return E.Cause(err, "handshake relay")
|
|
}
|
|
s.logger.TraceContext(ctx, "handshake relay finished")
|
|
s.handler.NewConnectionEx(ctx, bufio.NewCachedConn(newVerifiedConn(conn, hmacAdd, hmacVerify, nil), clientFirstFrame), source, destination, onClose)
|
|
return nil
|
|
}
|
|
}
|