mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Add trojan-local and server command
This commit is contained in:
parent
c66f869581
commit
01eeea9a2e
12 changed files with 634 additions and 48 deletions
|
@ -4,8 +4,6 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
|
@ -18,7 +16,9 @@ import (
|
|||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/random"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/protocol/trojan"
|
||||
transTLS "github.com/sagernet/sing/transport/tls"
|
||||
|
|
|
@ -325,7 +325,7 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
|
|||
}
|
||||
_payload := buf.StackNew()
|
||||
payload := common.Dup(_payload)
|
||||
err = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
err = conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
338
cli/trojan-local/main.go
Normal file
338
cli/trojan-local/main.go
Normal file
|
@ -0,0 +1,338 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
cTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/refraction-networking/utls"
|
||||
"github.com/sagernet/sing"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/redir"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/protocol/trojan"
|
||||
"github.com/sagernet/sing/transport/mixed"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const udpTimeout = 5 * 60
|
||||
|
||||
type flags struct {
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
ServerName string `json:"server_name"`
|
||||
Bind string `json:"local_address"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
Password string `json:"password"`
|
||||
Verbose bool `json:"verbose"`
|
||||
Insecure bool `json:"insecure"`
|
||||
ConfigFile string
|
||||
}
|
||||
|
||||
func main() {
|
||||
f := new(flags)
|
||||
|
||||
command := &cobra.Command{
|
||||
Use: "trojan-local",
|
||||
Short: "trojan client",
|
||||
Version: sing.VersionStr,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
run(cmd, f)
|
||||
},
|
||||
}
|
||||
|
||||
command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.")
|
||||
command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
|
||||
command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.")
|
||||
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
|
||||
command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
|
||||
command.Flags().BoolVarP(&f.Insecure, "insecure", "i", false, "Set insecure.")
|
||||
command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||
command.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "Set verbose mode.")
|
||||
|
||||
err := command.Execute()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, f *flags) {
|
||||
c, err := newClient(f)
|
||||
if err != nil {
|
||||
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
err = c.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
logrus.Info("mixed server started at ", c.TCPListener.Addr())
|
||||
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||
<-osSignals
|
||||
|
||||
c.Close()
|
||||
}
|
||||
|
||||
type client struct {
|
||||
*mixed.Listener
|
||||
server string
|
||||
key [trojan.KeyLength]byte
|
||||
sni string
|
||||
insecure bool
|
||||
dialer net.Dialer
|
||||
}
|
||||
|
||||
func newClient(f *flags) (*client, error) {
|
||||
if f.ConfigFile != "" {
|
||||
configFile, err := ioutil.ReadFile(f.ConfigFile)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read config file")
|
||||
}
|
||||
flagsNew := new(flags)
|
||||
err = json.Unmarshal(configFile, flagsNew)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode config file")
|
||||
}
|
||||
if flagsNew.Server != "" && f.Server == "" {
|
||||
f.Server = flagsNew.Server
|
||||
}
|
||||
if flagsNew.ServerPort != 0 && f.ServerPort == 0 {
|
||||
f.ServerPort = flagsNew.ServerPort
|
||||
}
|
||||
if flagsNew.Bind != "" && f.Bind == "" {
|
||||
f.Bind = flagsNew.Bind
|
||||
}
|
||||
if flagsNew.LocalPort != 0 && f.LocalPort == 0 {
|
||||
f.LocalPort = flagsNew.LocalPort
|
||||
}
|
||||
if flagsNew.Password != "" && f.Password == "" {
|
||||
f.Password = flagsNew.Password
|
||||
}
|
||||
if flagsNew.Insecure {
|
||||
f.Insecure = true
|
||||
}
|
||||
if flagsNew.Verbose {
|
||||
f.Verbose = true
|
||||
}
|
||||
}
|
||||
|
||||
if f.Verbose {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
if f.Server == "" {
|
||||
return nil, E.New("missing server address")
|
||||
} else if f.ServerPort == 0 {
|
||||
return nil, E.New("missing server port")
|
||||
}
|
||||
|
||||
c := &client{
|
||||
server: M.AddrPortFrom(M.ParseAddr(f.Server), f.ServerPort).String(),
|
||||
key: trojan.Key(f.Password),
|
||||
sni: f.ServerName,
|
||||
insecure: f.Insecure,
|
||||
}
|
||||
if c.sni == "" {
|
||||
c.sni = f.Server
|
||||
}
|
||||
|
||||
var bind netip.Addr
|
||||
if f.Bind != "" {
|
||||
addr, err := netip.ParseAddr(f.Bind)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "bad local address")
|
||||
}
|
||||
bind = addr
|
||||
} else {
|
||||
bind = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, redir.ModeDisabled, udpTimeout, c)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *client) connect(ctx context.Context) (*cTLS.Conn, error) {
|
||||
tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := cTLS.Client(tcpConn, &cTLS.Config{
|
||||
ServerName: c.sni,
|
||||
InsecureSkipVerify: c.insecure,
|
||||
})
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func (c *client) connectUTLS(ctx context.Context) (*tls.UConn, error) {
|
||||
tcpConn, err := c.dialer.DialContext(ctx, "tcp", c.server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.UClient(tcpConn, &tls.Config{
|
||||
ServerName: c.sni,
|
||||
InsecureSkipVerify: c.insecure,
|
||||
}, tls.HelloCustom)
|
||||
clientHelloSpec := tls.ClientHelloSpec{
|
||||
CipherSuites: []uint16{
|
||||
tls.GREASE_PLACEHOLDER,
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
tls.DISABLED_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.DISABLED_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.DISABLED_TLS_RSA_WITH_AES_256_CBC_SHA256,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
0xc008,
|
||||
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
},
|
||||
CompressionMethods: []byte{
|
||||
0x00, // compressionNone
|
||||
},
|
||||
Extensions: []tls.TLSExtension{
|
||||
&tls.UtlsGREASEExtension{},
|
||||
&tls.SNIExtension{},
|
||||
&tls.UtlsExtendedMasterSecretExtension{},
|
||||
&tls.RenegotiationInfoExtension{Renegotiation: tls.RenegotiateOnceAsClient},
|
||||
&tls.SupportedCurvesExtension{Curves: []tls.CurveID{
|
||||
tls.CurveID(tls.GREASE_PLACEHOLDER),
|
||||
tls.X25519,
|
||||
tls.CurveP256,
|
||||
tls.CurveP384,
|
||||
tls.CurveP521,
|
||||
}},
|
||||
&tls.SupportedPointsExtension{SupportedPoints: []byte{
|
||||
0x00, // pointFormatUncompressed
|
||||
}},
|
||||
&tls.ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
|
||||
&tls.StatusRequestExtension{},
|
||||
&tls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []tls.SignatureScheme{
|
||||
tls.ECDSAWithP256AndSHA256,
|
||||
tls.PSSWithSHA256,
|
||||
tls.PKCS1WithSHA256,
|
||||
tls.ECDSAWithP384AndSHA384,
|
||||
tls.ECDSAWithSHA1,
|
||||
tls.PSSWithSHA384,
|
||||
tls.PSSWithSHA384,
|
||||
tls.PKCS1WithSHA384,
|
||||
tls.PSSWithSHA512,
|
||||
tls.PKCS1WithSHA512,
|
||||
tls.PKCS1WithSHA1,
|
||||
}},
|
||||
&tls.SCTExtension{},
|
||||
&tls.KeyShareExtension{KeyShares: []tls.KeyShare{
|
||||
{Group: tls.CurveID(tls.GREASE_PLACEHOLDER), Data: []byte{0}},
|
||||
{Group: tls.X25519},
|
||||
}},
|
||||
&tls.PSKKeyExchangeModesExtension{Modes: []uint8{
|
||||
tls.PskModeDHE,
|
||||
}},
|
||||
&tls.SupportedVersionsExtension{Versions: []uint16{
|
||||
tls.GREASE_PLACEHOLDER,
|
||||
tls.VersionTLS13,
|
||||
tls.VersionTLS12,
|
||||
tls.VersionTLS11,
|
||||
tls.VersionTLS10,
|
||||
}},
|
||||
&tls.UtlsGREASEExtension{},
|
||||
&tls.UtlsPaddingExtension{GetPaddingLen: tls.BoringPaddingStyle},
|
||||
},
|
||||
}
|
||||
err = tlsConn.ApplyPreset(&clientHelloSpec)
|
||||
if err != nil {
|
||||
tcpConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
logrus.Info("outbound ", metadata.Protocol, " TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||
|
||||
tlsConn, err := c.connect(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientConn := trojan.NewClientConn(tlsConn, c.key, metadata.Destination)
|
||||
|
||||
err = conn.SetReadDeadline(time.Now().Add(300 * time.Millisecond))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_request := buf.StackNew()
|
||||
request := common.Dup(_request)
|
||||
_, err = request.ReadFrom(conn)
|
||||
if err != nil && !E.IsTimeout(err) {
|
||||
return E.Cause(err, "read payload")
|
||||
}
|
||||
|
||||
err = conn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = clientConn.Write(request.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "client handshake")
|
||||
}
|
||||
|
||||
return rw.CopyConn(ctx, clientConn, conn)
|
||||
}
|
||||
|
||||
func (c *client) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
|
||||
logrus.Info("outbound ", metadata.Protocol, " UDP ", metadata.Source, " ==> ", metadata.Destination)
|
||||
|
||||
tlsConn, err := c.connect(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
/*err = trojan.ClientHandshakeRaw(tlsConn, c.key, trojan.CommandUDP, metadata.Destination, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return socks.CopyPacketConn(ctx, &trojan.PacketConn{Conn: tlsConn}, conn)*/
|
||||
clientConn := trojan.NewClientPacketConn(tlsConn, c.key)
|
||||
return socks.CopyPacketConn(ctx, clientConn, conn)
|
||||
}
|
||||
|
||||
func (c *client) HandleError(err error) {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
logrus.Warn(err)
|
||||
}
|
192
cli/trojan-server/main.go
Normal file
192
cli/trojan-server/main.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/random"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"github.com/sagernet/sing/protocol/trojan"
|
||||
"github.com/sagernet/sing/transport/tcp"
|
||||
transTLS "github.com/sagernet/sing/transport/tls"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const udpTimeout = 5 * 60
|
||||
|
||||
type flags struct {
|
||||
Server string `json:"server"`
|
||||
ServerPort uint16 `json:"server_port"`
|
||||
ServerName string `json:"server_name"`
|
||||
Bind string `json:"local_address"`
|
||||
LocalPort uint16 `json:"local_port"`
|
||||
Password string `json:"password"`
|
||||
Verbose bool `json:"verbose"`
|
||||
Insecure bool `json:"insecure"`
|
||||
ConfigFile string
|
||||
}
|
||||
|
||||
func main() {
|
||||
f := new(flags)
|
||||
|
||||
command := &cobra.Command{
|
||||
Use: "trojan-local",
|
||||
Short: "trojan client",
|
||||
Version: sing.VersionStr,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
run(cmd, f)
|
||||
},
|
||||
}
|
||||
|
||||
command.Flags().StringVarP(&f.Server, "server", "s", "", "Set the server’s hostname or IP.")
|
||||
command.Flags().Uint16VarP(&f.ServerPort, "server-port", "p", 0, "Set the server’s port number.")
|
||||
command.Flags().StringVarP(&f.Bind, "local-address", "b", "", "Set the local address.")
|
||||
command.Flags().Uint16VarP(&f.LocalPort, "local-port", "l", 0, "Set the local port number.")
|
||||
command.Flags().StringVarP(&f.Password, "password", "k", "", "Set the password. The server and the client should use the same password.")
|
||||
command.Flags().BoolVarP(&f.Insecure, "insecure", "i", false, "Set insecure.")
|
||||
command.Flags().StringVarP(&f.ConfigFile, "config", "c", "", "Use a configuration file.")
|
||||
command.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "Set verbose mode.")
|
||||
|
||||
err := command.Execute()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func run(cmd *cobra.Command, f *flags) {
|
||||
c, err := newServer(f)
|
||||
if err != nil {
|
||||
logrus.StandardLogger().Log(logrus.FatalLevel, err, "\n\n")
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
err = c.tcpIn.Start()
|
||||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
|
||||
logrus.Info("server started at ", c.tcpIn.Addr())
|
||||
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)
|
||||
<-osSignals
|
||||
|
||||
c.tcpIn.Close()
|
||||
}
|
||||
|
||||
type server struct {
|
||||
tcpIn *tcp.Listener
|
||||
service trojan.Service[int]
|
||||
}
|
||||
|
||||
func newServer(f *flags) (*server, error) {
|
||||
s := new(server)
|
||||
|
||||
if f.ConfigFile != "" {
|
||||
configFile, err := ioutil.ReadFile(f.ConfigFile)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read config file")
|
||||
}
|
||||
flagsNew := new(flags)
|
||||
err = json.Unmarshal(configFile, flagsNew)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode config file")
|
||||
}
|
||||
if flagsNew.Server != "" && f.Server == "" {
|
||||
f.Server = flagsNew.Server
|
||||
}
|
||||
if flagsNew.ServerPort != 0 && f.ServerPort == 0 {
|
||||
f.ServerPort = flagsNew.ServerPort
|
||||
}
|
||||
if flagsNew.Bind != "" && f.Bind == "" {
|
||||
f.Bind = flagsNew.Bind
|
||||
}
|
||||
if flagsNew.LocalPort != 0 && f.LocalPort == 0 {
|
||||
f.LocalPort = flagsNew.LocalPort
|
||||
}
|
||||
if flagsNew.Password != "" && f.Password == "" {
|
||||
f.Password = flagsNew.Password
|
||||
}
|
||||
if flagsNew.Insecure {
|
||||
f.Insecure = true
|
||||
}
|
||||
if flagsNew.Verbose {
|
||||
f.Verbose = true
|
||||
}
|
||||
}
|
||||
|
||||
if f.Verbose {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
|
||||
if f.Server == "" {
|
||||
return nil, E.New("missing server address")
|
||||
} else if f.ServerPort == 0 {
|
||||
return nil, E.New("missing server port")
|
||||
}
|
||||
|
||||
var bind netip.Addr
|
||||
if f.Server != "" {
|
||||
addr, err := netip.ParseAddr(f.Server)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "bad server address")
|
||||
}
|
||||
bind = addr
|
||||
} else {
|
||||
bind = netip.IPv6Unspecified()
|
||||
}
|
||||
s.service = trojan.NewService[int](s)
|
||||
common.Must(s.service.AddUser(0, f.Password))
|
||||
s.tcpIn = tcp.NewTCPListener(netip.AddrPortFrom(bind, f.ServerPort), s)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *server) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
if metadata.Protocol != "trojan" {
|
||||
logrus.Trace("inbound raw TCP from ", metadata.Source)
|
||||
tlsConn := tls.Server(conn, &tls.Config{
|
||||
Rand: random.Blake3KeyedHash(),
|
||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return transTLS.GenerateCertificate(info.ServerName)
|
||||
},
|
||||
})
|
||||
return s.service.NewConnection(ctx, tlsConn, metadata)
|
||||
}
|
||||
destConn, err := network.SystemDialer.DialContext(context.Background(), "tcp", metadata.Destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Info("inbound TCP ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||
return rw.CopyConn(ctx, conn, destConn)
|
||||
}
|
||||
|
||||
func (s *server) NewPacketConnection(ctx context.Context, conn socks.PacketConn, metadata M.Metadata) error {
|
||||
logrus.Info("inbound UDP ", metadata.Source, " ==> ", metadata.Destination)
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return socks.CopyNetPacketConn(ctx, udpConn, conn)
|
||||
}
|
||||
|
||||
func (s *server) HandleError(err error) {
|
||||
common.Close(err)
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
logrus.Warn(err)
|
||||
}
|
|
@ -94,6 +94,9 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
|
|||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
if b.end+len(data) > b.Cap() {
|
||||
panic("buffer overflow")
|
||||
}
|
||||
n = copy(b.data[b.end:], data)
|
||||
b.end += n
|
||||
return
|
||||
|
@ -102,7 +105,7 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
|
|||
func (b *Buffer) ExtendHeader(size int) []byte {
|
||||
if b.start >= size {
|
||||
b.start -= size
|
||||
return b.data[b.start-size : b.start]
|
||||
return b.data[b.start : b.start+size]
|
||||
} else {
|
||||
/*offset := size - b.start
|
||||
end := b.end + size
|
||||
|
@ -326,7 +329,7 @@ func (b Buffer) Len() int {
|
|||
}
|
||||
|
||||
func (b Buffer) Cap() int {
|
||||
return cap(b.data)
|
||||
return len(b.data)
|
||||
}
|
||||
|
||||
func (b Buffer) Bytes() []byte {
|
||||
|
|
|
@ -62,7 +62,7 @@ func (s *Serializer) AddressLen(addr Addr) int {
|
|||
case AddressFamilyIPv6:
|
||||
return 17
|
||||
default:
|
||||
return 1 + len(addr.Fqdn())
|
||||
return 2 + len(addr.Fqdn())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -44,17 +44,42 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
|||
err := task.Run(ctx, func() error {
|
||||
defer CloseRead(conn)
|
||||
defer CloseWrite(dest)
|
||||
return common.Error(io.Copy(dest, conn))
|
||||
return common.Error(Copy(dest, conn))
|
||||
}, func() error {
|
||||
defer CloseRead(dest)
|
||||
defer CloseWrite(conn)
|
||||
return common.Error(io.Copy(conn, dest))
|
||||
return common.Error(Copy(conn, dest))
|
||||
})
|
||||
conn.Close()
|
||||
dest.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
|
||||
if wt, ok := src.(io.WriterTo); ok {
|
||||
return wt.WriteTo(dst)
|
||||
}
|
||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||
return rt.ReadFrom(src)
|
||||
}
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
for {
|
||||
buffer.FullReset()
|
||||
_, err = buffer.ReadFrom(src)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
var cn int
|
||||
cn, err = dst.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n += int64(cn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
_buffer := buf.With(make([]byte, buf.UDPBufferSize))
|
||||
|
|
1
go.mod
1
go.mod
|
@ -35,6 +35,7 @@ require (
|
|||
github.com/oschwald/maxminddb-golang v1.9.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/refraction-networking/utls v1.1.0 // indirect
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -44,6 +44,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
|||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.1.0 h1:dKXJwSqni/t5csYJ+aQcEgqB7AMWYi6EUc9u3bEmjX0=
|
||||
github.com/refraction-networking/utls v1.1.0/go.mod h1:tz9gX959MEFfFN5whTIocCLUG57WiILqtdVxI8c6Wj0=
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg=
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
|
|
|
@ -113,16 +113,45 @@ func Key(password string) [KeyLength]byte {
|
|||
return key
|
||||
}
|
||||
|
||||
func ClientHandshake(conn net.Conn, key [56]byte, destination *M.AddrPort, payload []byte) error {
|
||||
bufferLen := KeyLength + socks.AddressSerializer.AddrPortLen(destination) + 5
|
||||
func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination *M.AddrPort, payload []byte) error {
|
||||
_, err := conn.Write(key[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.Write(CRLF[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.Write([]byte{command})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = socks.AddressSerializer.WriteAddrPort(conn, destination)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.Write(CRLF[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(payload) > 0 {
|
||||
_, err = conn.Write(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload []byte) error {
|
||||
headerLen := KeyLength + socks.AddressSerializer.AddrPortLen(destination) + 5
|
||||
var header *buf.Buffer
|
||||
var writeHeader bool
|
||||
if len(payload) > 0 && bufferLen+len(payload) < 65535 {
|
||||
buffer := buf.Make(bufferLen + len(payload))
|
||||
copy(buffer[bufferLen:], payload)
|
||||
if len(payload) > 0 && headerLen+len(payload) < 65535 {
|
||||
buffer := buf.Make(headerLen + len(payload))
|
||||
header = buf.With(common.Dup(buffer))
|
||||
} else {
|
||||
buffer := buf.Make(bufferLen)
|
||||
buffer := buf.Make(headerLen)
|
||||
header = buf.With(common.Dup(buffer))
|
||||
writeHeader = true
|
||||
}
|
||||
|
@ -131,11 +160,13 @@ func ClientHandshake(conn net.Conn, key [56]byte, destination *M.AddrPort, paylo
|
|||
common.Must(header.WriteByte(CommandTCP))
|
||||
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
|
||||
common.Must1(header.Write(CRLF))
|
||||
common.Must1(header.Write(payload))
|
||||
|
||||
_, err := conn.Write(header.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write request")
|
||||
}
|
||||
|
||||
if writeHeader {
|
||||
_, err = conn.Write(payload)
|
||||
if err != nil {
|
||||
|
@ -145,14 +176,15 @@ func ClientHandshake(conn net.Conn, key [56]byte, destination *M.AddrPort, paylo
|
|||
return nil
|
||||
}
|
||||
|
||||
func ClientHandshakePacket(conn net.Conn, key [56]byte, destination *M.AddrPort, payload *buf.Buffer) error {
|
||||
bufferLen := KeyLength + 2*socks.AddressSerializer.AddrPortLen(destination) + 9
|
||||
func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination *M.AddrPort, payload *buf.Buffer) error {
|
||||
headerLen := KeyLength + 2*socks.AddressSerializer.AddrPortLen(destination) + 9
|
||||
payloadLen := payload.Len()
|
||||
var header *buf.Buffer
|
||||
var writeHeader bool
|
||||
if payload.Start() >= bufferLen {
|
||||
header = buf.With(payload.ExtendHeader(bufferLen))
|
||||
if payload.Start() >= headerLen {
|
||||
header = buf.With(payload.ExtendHeader(headerLen))
|
||||
} else {
|
||||
buffer := buf.Make(bufferLen)
|
||||
buffer := buf.Make(headerLen)
|
||||
header = buf.With(common.Dup(buffer))
|
||||
writeHeader = true
|
||||
}
|
||||
|
@ -162,19 +194,20 @@ func ClientHandshakePacket(conn net.Conn, key [56]byte, destination *M.AddrPort,
|
|||
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
|
||||
common.Must1(header.Write(CRLF))
|
||||
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
|
||||
common.Must(binary.Write(header, binary.BigEndian, uint16(payload.Len())))
|
||||
common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
|
||||
common.Must1(header.Write(CRLF))
|
||||
|
||||
_, err := conn.Write(header.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write request")
|
||||
}
|
||||
if writeHeader {
|
||||
_, err = conn.Write(payload.Bytes())
|
||||
_, err := conn.Write(header.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write payload")
|
||||
return E.Cause(err, "write request")
|
||||
}
|
||||
}
|
||||
|
||||
_, err := conn.Write(payload.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write payload")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -207,6 +240,7 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) err
|
|||
headerOverload := socks.AddressSerializer.AddrPortLen(destination) + 4
|
||||
var header *buf.Buffer
|
||||
var writeHeader bool
|
||||
bufferLen := buffer.Len()
|
||||
if buffer.Start() >= headerOverload {
|
||||
header = buf.With(buffer.ExtendHeader(headerOverload))
|
||||
} else {
|
||||
|
@ -214,25 +248,16 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination *M.AddrPort) err
|
|||
_buffer := buf.Make(headerOverload)
|
||||
header = buf.With(common.Dup(_buffer))
|
||||
}
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write socks addr")
|
||||
}
|
||||
err = binary.Write(header, binary.BigEndian, uint16(buffer.Len()))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write chunk length")
|
||||
}
|
||||
_, err = header.Write(CRLF)
|
||||
if err != nil {
|
||||
return E.Cause(err, "write crlf")
|
||||
}
|
||||
common.Must(socks.AddressSerializer.WriteAddrPort(header, destination))
|
||||
common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
|
||||
common.Must1(header.Write(CRLF))
|
||||
if writeHeader {
|
||||
_, err = conn.Write(header.Bytes())
|
||||
_, err := conn.Write(header.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write packet header")
|
||||
}
|
||||
}
|
||||
_, err = conn.Write(buffer.Bytes())
|
||||
_, err := conn.Write(buffer.Bytes())
|
||||
if err != nil {
|
||||
return E.Cause(err, "write packet")
|
||||
}
|
||||
|
|
|
@ -133,20 +133,20 @@ process:
|
|||
if command == CommandTCP {
|
||||
return s.handler.NewConnection(&userCtx, conn, metadata)
|
||||
} else {
|
||||
return s.handler.NewPacketConnection(&userCtx, &packetConn{conn}, metadata)
|
||||
return s.handler.NewPacketConnection(&userCtx, &PacketConn{conn}, metadata)
|
||||
}
|
||||
}
|
||||
|
||||
type packetConn struct {
|
||||
type PacketConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *packetConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
return ReadPacket(c, buffer)
|
||||
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
return ReadPacket(c.Conn, buffer)
|
||||
}
|
||||
|
||||
func (c *packetConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
return WritePacket(c, buffer, destination)
|
||||
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
return WritePacket(c.Conn, buffer, destination)
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
|
|
|
@ -43,8 +43,8 @@ func GenerateCertificate(hosts ...string) (*tls.Certificate, error) {
|
|||
Subject: pkix.Name{
|
||||
Organization: []string{"Cloudflare, Inc."},
|
||||
},
|
||||
NotBefore: endAt,
|
||||
NotAfter: createAt,
|
||||
NotBefore: createAt,
|
||||
NotAfter: endAt,
|
||||
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue