hysteria/pkg/forwarder/server.go

119 lines
3.4 KiB
Go

package forwarder
import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"time"

	"github.com/tobyxdd/hysteria/internal/forwarder"
)
// server is the Server implementation returned by NewServer. It owns one
// forwarder.QUICServer per listen address that has been added.
type server struct {
	// config holds the options after NewServer has filled in defaults.
	config ServerConfig
	// callbacks receives client/stream/TCP events from every forwarder;
	// each callback is optional and skipped when nil.
	callbacks ServerCallbacks
	// entries maps a listen address (as passed to Add) to its running forwarder.
	entries map[string]*forwarder.QUICServer
}
// NewServer returns a Server that has no forwarders yet; use Add to start them.
//
// Unset options in config are replaced with defaults before they are stored:
//   - a nil TLSConfig becomes a generated self-signed config;
//   - a nil MaxSpeedPerClient becomes a zero send/receive speed;
//   - zero receive windows become defaultReceiveWindowConn / defaultReceiveWindow;
//   - a non-positive MaxConnectionPerClient becomes defaultMaxClientConn.
//
// callbacks are shared by every forwarder the returned Server starts.
func NewServer(config ServerConfig, callbacks ServerCallbacks) Server {
	if config.TLSConfig == nil {
		config.TLSConfig = generateInsecureTLSConfig()
	}
	if config.MaxSpeedPerClient == nil {
		config.MaxSpeedPerClient = &Speed{SendBPS: 0, ReceiveBPS: 0}
	}
	if config.MaxReceiveWindowPerConnection == 0 {
		config.MaxReceiveWindowPerConnection = defaultReceiveWindowConn
	}
	if config.MaxReceiveWindowPerClient == 0 {
		config.MaxReceiveWindowPerClient = defaultReceiveWindow
	}
	if config.MaxConnectionPerClient <= 0 {
		config.MaxConnectionPerClient = defaultMaxClientConn
	}

	srv := &server{
		config:    config,
		callbacks: callbacks,
		entries:   make(map[string]*forwarder.QUICServer),
	}
	return srv
}
// Add creates a forwarder.QUICServer listening on listenAddr and forwarding
// to remoteAddr, using the server's (defaulted) config. It registers the
// forwarder under listenAddr so it can later be stopped with Remove and
// reported by Stats.
//
// Every event raised by the forwarder is passed to the matching callback in
// s.callbacks, if that callback is non-nil. The listenAddr is prepended so
// one set of callbacks can tell forwarders apart.
//
// Add returns an error without starting anything if a forwarder is already
// registered for listenAddr. Otherwise it returns any error from
// forwarder.NewQUICServer.
func (s *server) Add(listenAddr, remoteAddr string) error {
	// Overwriting an existing entry would orphan the previous QUICServer: it
	// would keep running, but Remove could no longer reach it to Close it.
	// This matters e.g. for ":0", where a second bind succeeds on a new port.
	if existing, ok := s.entries[listenAddr]; ok && existing != nil {
		return fmt.Errorf("forwarder for listen address %q already exists", listenAddr)
	}
	qs, err := forwarder.NewQUICServer(listenAddr, remoteAddr, s.config.BannerMessage, s.config.TLSConfig,
		s.config.MaxSpeedPerClient.SendBPS, s.config.MaxSpeedPerClient.ReceiveBPS,
		s.config.MaxReceiveWindowPerConnection, s.config.MaxReceiveWindowPerClient,
		s.config.MaxConnectionPerClient, forwarder.CongestionFactory(s.config.CongestionFactory),
		func(addr net.Addr, name string, sSend uint64, sRecv uint64) {
			if s.callbacks.ClientConnectedCallback != nil {
				s.callbacks.ClientConnectedCallback(listenAddr, addr, name, sSend, sRecv)
			}
		},
		func(addr net.Addr, name string, err error) {
			if s.callbacks.ClientDisconnectedCallback != nil {
				s.callbacks.ClientDisconnectedCallback(listenAddr, addr, name, err)
			}
		},
		func(addr net.Addr, name string, id int) {
			if s.callbacks.ClientNewStreamCallback != nil {
				s.callbacks.ClientNewStreamCallback(listenAddr, addr, name, id)
			}
		},
		func(addr net.Addr, name string, id int, err error) {
			if s.callbacks.ClientStreamClosedCallback != nil {
				s.callbacks.ClientStreamClosedCallback(listenAddr, addr, name, id, err)
			}
		},
		func(remoteAddr string, err error) {
			if s.callbacks.TCPErrorCallback != nil {
				s.callbacks.TCPErrorCallback(listenAddr, remoteAddr, err)
			}
		},
	)
	if err != nil {
		return err
	}
	s.entries[listenAddr] = qs
	return nil
}
// Remove stops the forwarder registered under listenAddr and forgets it.
// The entry is deleted from the registry even when Close returns an error.
// Removing an address that was never added is a no-op that returns nil.
func (s *server) Remove(listenAddr string) error {
	qs, found := s.entries[listenAddr]
	// Deferred so the entry is dropped after Close, whatever Close does.
	defer delete(s.entries, listenAddr)
	if !found || qs == nil {
		return nil
	}
	return qs.Close()
}
// Stats builds a new map, keyed by listen address, holding each running
// forwarder's remote address and byte counters as reported by its Stats
// method at the time of the call.
func (s *server) Stats() map[string]Stats {
	result := make(map[string]Stats, len(s.entries))
	for listenAddr, qs := range s.entries {
		remote, inbound, outbound := qs.Stats()
		result[listenAddr] = Stats{RemoteAddr: remote, inboundBytes: inbound, outboundBytes: outbound}
	}
	return result
}
// generateInsecureTLSConfig builds a TLS config around a freshly generated,
// self-signed RSA certificate, advertising TLSAppProtocol via ALPN.
// NewServer falls back to it when no TLSConfig is supplied. The certificate
// chains to no CA, so peers can only accept it by skipping verification or
// pinning it — hence "insecure".
//
// It panics if key generation, certificate creation or key-pair parsing fails.
func generateInsecureTLSConfig() *tls.Config {
	// 2048 bits is the minimum RSA size in current guidance (NIST SP 800-131A);
	// 1024-bit keys are deprecated.
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}
	// Give the certificate an explicit validity window. If NotBefore/NotAfter
	// are left unset they are the zero time (year 1), which yields a
	// certificate that is already expired for any peer checking dates.
	notBefore := time.Now().Add(-time.Hour) // tolerate modest clock skew
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		NotBefore:    notBefore,
		NotAfter:     notBefore.AddDate(10, 0, 0),
	}
	// Self-signed: the template is both subject and issuer.
	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
	if err != nil {
		panic(err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		panic(err)
	}
	return &tls.Config{
		Certificates: []tls.Certificate{tlsCert},
		NextProtos:   []string{TLSAppProtocol},
	}
}