hysteria/internal/forwarder/server.go

176 lines
5.1 KiB
Go

package forwarder
import (
"context"
"crypto/tls"
"errors"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils"
"net"
"sync/atomic"
)
// QUICServer accepts QUIC clients on a listener and forwards every stream
// they open to a single TCP address.
type QUICServer struct {
	// Traffic counters, accessed atomically. Keep them first: sync/atomic
	// needs 64-bit alignment on 32-bit platforms, which is only guaranteed
	// for the first word of an allocated struct.
	inboundBytes  uint64
	outboundBytes uint64

	listener   quic.Listener
	remoteAddr string // TCP address dialed for each forwarded stream
	banner     string // returned to clients in the speed response

	// Server-side speed caps used during speed negotiation; a value of
	// zero disables the corresponding cap.
	sendBPS uint64
	recvBPS uint64

	// newCongestion is optional; when nil the session's congestion
	// controller is left untouched.
	newCongestion CongestionFactory

	// Event callbacks. They are invoked without nil checks.
	onClientConnected    ClientConnectedCallback
	onClientDisconnected ClientDisconnectedCallback
	onClientNewStream    ClientNewStreamCallback
	onClientStreamClosed ClientStreamClosedCallback
	onTCPError           TCPErrorCallback
}
// NewQUICServer opens a QUIC listener on addr and returns a server that
// forwards client streams to remoteAddr over TCP.
//
// recvWindowConn and recvWindowClients are passed to quic-go as the maximum
// per-stream and per-connection receive flow-control windows respectively,
// and clientMaxConn as the maximum number of incoming streams. sendBPS and
// recvBPS cap the negotiated speeds (zero means uncapped). newCongestion may
// be nil; the callbacks are called unconditionally.
//
// On success the accept loop is already running in its own goroutine. If the
// listener cannot be created, the error is returned and no goroutine is
// started.
func NewQUICServer(addr string, remoteAddr string, banner string, tlsConfig *tls.Config,
	sendBPS uint64, recvBPS uint64, recvWindowConn uint64, recvWindowClients uint64,
	clientMaxConn int, newCongestion CongestionFactory,
	onClientConnected ClientConnectedCallback,
	onClientDisconnected ClientDisconnectedCallback,
	onClientNewStream ClientNewStreamCallback,
	onClientStreamClosed ClientStreamClosedCallback,
	onTCPError TCPErrorCallback) (*QUICServer, error) {
	quicConfig := &quic.Config{
		MaxReceiveStreamFlowControlWindow:     recvWindowConn,
		MaxReceiveConnectionFlowControlWindow: recvWindowClients,
		MaxIncomingStreams:                    clientMaxConn,
		KeepAlive:                             true,
	}
	ln, err := quic.ListenAddr(addr, tlsConfig, quicConfig)
	if err != nil {
		return nil, err
	}

	srv := &QUICServer{
		listener:             ln,
		remoteAddr:           remoteAddr,
		banner:               banner,
		sendBPS:              sendBPS,
		recvBPS:              recvBPS,
		newCongestion:        newCongestion,
		onClientConnected:    onClientConnected,
		onClientDisconnected: onClientDisconnected,
		onClientNewStream:    onClientNewStream,
		onClientStreamClosed: onClientStreamClosed,
		onTCPError:           onTCPError,
	}
	// The loop runs until the listener's Accept returns an error.
	go srv.acceptLoop()
	return srv, nil
}
// Close closes the underlying QUIC listener and reports the listener's
// close error, if any.
func (s *QUICServer) Close() error {
	err := s.listener.Close()
	return err
}
// Stats returns the TCP address streams are forwarded to, followed by the
// current inbound and outbound byte counters. Each counter is loaded
// atomically, but the two loads are not a single consistent snapshot.
func (s *QUICServer) Stats() (string, uint64, uint64) {
	in := atomic.LoadUint64(&s.inboundBytes)
	out := atomic.LoadUint64(&s.outboundBytes)
	return s.remoteAddr, in, out
}
// acceptLoop accepts sessions from the listener and serves each one in its
// own goroutine. It returns on the first Accept error; the error itself is
// discarded.
func (s *QUICServer) acceptLoop() {
	for {
		session, err := s.listener.Accept(context.Background())
		if err != nil {
			return
		}
		go s.handleClient(session)
	}
}
// handleClient serves a single client session: it waits for the control
// stream, runs the speed negotiation on it, and then forwards every further
// stream the client opens until accepting a stream fails.
func (s *QUICServer) handleClient(cs quic.Session) {
	// The client must open a control stream within controlStreamTimeout
	// and describe itself on it.
	ctlCtx, cancel := context.WithTimeout(context.Background(), controlStreamTimeout)
	ctl, err := cs.AcceptStream(ctlCtx)
	cancel()
	if err != nil {
		_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
		return
	}

	name, negotiatedSend, negotiatedRecv, err := s.handleControlStream(cs, ctl)
	if err != nil {
		_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
		return
	}

	// The client counts as connected only once the exchange has succeeded,
	// so onClientConnected never fires for the failure paths above.
	s.onClientConnected(cs.RemoteAddr(), name, negotiatedSend, negotiatedRecv)

	// Forward each subsequent stream in its own goroutine.
	for {
		stream, acceptErr := cs.AcceptStream(context.Background())
		if acceptErr != nil {
			// The accept error is what gets reported as the reason for
			// the disconnect. The close error is ignored because the
			// session is being torn down either way.
			s.onClientDisconnected(cs.RemoteAddr(), name, acceptErr)
			_ = cs.CloseWithError(closeErrorCodeGeneric, "generic")
			return
		}
		go s.handleStream(cs.RemoteAddr(), name, stream)
	}
}
// handleControlStream performs the speed negotiation on the control stream.
//
// It reads the client's speed request, rejects it if the speed section is
// missing or either rate is zero, caps the rates by the server's configured
// limits (a limit of zero means no cap), writes the response and, when a
// congestion factory is configured, installs a congestion controller built
// from the negotiated server send rate.
//
// It returns the client name plus the negotiated server send and receive
// rates. On error the other return values are zero.
func (s *QUICServer) handleControlStream(cs quic.Session, stream quic.Stream) (string, uint64, uint64, error) {
	req, err := readClientSpeedRequest(stream)
	if err != nil {
		return "", 0, 0, err
	}
	clientSpeed := req.Speed
	if clientSpeed == nil || clientSpeed.SendBps == 0 || clientSpeed.ReceiveBps == 0 {
		return "", 0, 0, errors.New("incorrect speed information provided by the client")
	}

	// Directions are mirrored: what the client receives is what the server
	// sends, and vice versa.
	txBPS, rxBPS := clientSpeed.ReceiveBps, clientSpeed.SendBps

	txCapped := s.sendBPS > 0 && txBPS > s.sendBPS
	if txCapped {
		txBPS = s.sendBPS
	}
	rxCapped := s.recvBPS > 0 && rxBPS > s.recvBPS
	if rxCapped {
		rxBPS = s.recvBPS
	}

	// Tell the client the server's limits and the rates actually in effect.
	resp := &ServerSpeedResponse{
		Banner:  s.banner,
		Limited: txCapped || rxCapped,
		Limit: &Speed{
			SendBps:    s.sendBPS,
			ReceiveBps: s.recvBPS,
		},
		Speed: &Speed{
			SendBps:    txBPS,
			ReceiveBps: rxBPS,
		},
	}
	if err := writeServerSpeedResponse(stream, resp); err != nil {
		return "", 0, 0, err
	}

	// Congestion control is configured only after the response was written.
	if s.newCongestion != nil {
		cs.SetCongestion(s.newCongestion(txBPS))
	}
	return req.Name, txBPS, rxBPS, nil
}
// handleStream forwards one client stream to the configured TCP address.
//
// It reports the new stream, dials remoteAddr and pipes data in both
// directions. It returns as soon as either direction finishes; the deferred
// closes then tear down both the QUIC stream and the TCP connection.
// onClientStreamClosed receives the dial error, or whatever the first
// finished pipe returned.
func (s *QUICServer) handleStream(addr net.Addr, name string, stream quic.Stream) {
	streamID := int(stream.StreamID())
	s.onClientNewStream(addr, name, streamID)
	defer stream.Close()

	tcpConn, err := net.Dial("tcp", s.remoteAddr)
	if err != nil {
		s.onTCPError(s.remoteAddr, err)
		s.onClientStreamClosed(addr, name, streamID, err)
		return
	}
	defer tcpConn.Close()

	// Capacity 2 lets the slower pipe goroutine deliver its result and
	// exit even though only the first result is ever received.
	results := make(chan error, 2)
	go func() {
		// TCP to QUIC
		results <- utils.Pipe(tcpConn, stream, &s.outboundBytes)
	}()
	go func() {
		// QUIC to TCP
		results <- utils.Pipe(stream, tcpConn, &s.inboundBytes)
	}()

	// Whichever direction ends first decides the outcome.
	firstErr := <-results
	s.onClientStreamClosed(addr, name, streamID, firstErr)
}