mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-04 21:17:47 +03:00
feat: traffic logger with disconnect
This commit is contained in:
parent
5b54edd09a
commit
4334d8afb8
4 changed files with 66 additions and 18 deletions
|
@ -12,24 +12,31 @@ import (
|
|||
|
||||
type testTrafficLogger struct {
|
||||
Tx, Rx uint64
|
||||
Block atomic.Bool
|
||||
}
|
||||
|
||||
func (l *testTrafficLogger) Log(id string, tx, rx uint64) bool {
|
||||
atomic.AddUint64(&l.Tx, tx)
|
||||
atomic.AddUint64(&l.Rx, rx)
|
||||
return true
|
||||
return !l.Block.Load()
|
||||
}
|
||||
|
||||
func (l *testTrafficLogger) Get() (tx, rx uint64) {
|
||||
return atomic.LoadUint64(&l.Tx), atomic.LoadUint64(&l.Rx)
|
||||
}
|
||||
|
||||
func (l *testTrafficLogger) SetBlock(block bool) {
|
||||
l.Block.Store(block)
|
||||
}
|
||||
|
||||
func (l *testTrafficLogger) Reset() {
|
||||
atomic.StoreUint64(&l.Tx, 0)
|
||||
atomic.StoreUint64(&l.Rx, 0)
|
||||
}
|
||||
|
||||
// TestServerTrafficLogger tests that the server's TrafficLogger interface is working correctly.
|
||||
// More specifically, it tests that the server is correctly logging traffic in both directions,
|
||||
// and that it is correctly disconnecting clients when the traffic logger returns false.
|
||||
func TestServerTrafficLogger(t *testing.T) {
|
||||
tl := &testTrafficLogger{}
|
||||
|
||||
|
@ -147,4 +154,20 @@ func TestServerTrafficLogger(t *testing.T) {
|
|||
if tx != uint64(len(sData)) || rx != uint64(len(sData)*2) {
|
||||
t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(sData)*2, tx, rx)
|
||||
}
|
||||
|
||||
// Check the disconnect client functionality
|
||||
tl.SetBlock(true)
|
||||
|
||||
// Send and receive TCP data again
|
||||
sData = []byte("1234")
|
||||
_, err = tConn.Write(sData)
|
||||
if err != nil {
|
||||
t.Fatal("error writing to TCP:", err)
|
||||
}
|
||||
// This should fail instantly without reading any data
|
||||
// io.Copy should return nil as EOF is treated as a non-error though
|
||||
n, err := io.Copy(io.Discard, tConn)
|
||||
if n != 0 || err != nil {
|
||||
t.Fatal("expected 0 bytes read and nil error, got", n, err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -180,5 +180,5 @@ type EventLogger interface {
|
|||
// bandwidth limits or post-connection authentication, for example.
|
||||
// The implementation of this interface must be thread-safe.
|
||||
type TrafficLogger interface {
|
||||
Log(id string, tx, rx uint64) bool
|
||||
Log(id string, tx, rx uint64) (ok bool)
|
||||
}
|
||||
|
|
|
@ -1,16 +1,22 @@
|
|||
package server
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error {
|
||||
var errDisconnect = errors.New("traffic logger requested disconnect")
|
||||
|
||||
func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64) bool) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[0:nr])
|
||||
if nw > 0 {
|
||||
log(uint64(nw))
|
||||
if !log(uint64(nr)) {
|
||||
// Log returns false, which means that the client should be disconnected
|
||||
return errDisconnect
|
||||
}
|
||||
_, ew := dst.Write(buf[0:nr])
|
||||
if ew != nil {
|
||||
return ew
|
||||
}
|
||||
|
@ -28,13 +34,13 @@ func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error {
|
|||
func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger) error {
|
||||
errChan := make(chan error, 2)
|
||||
go func() {
|
||||
errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) {
|
||||
l.Log(id, 0, n)
|
||||
errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) bool {
|
||||
return l.Log(id, 0, n)
|
||||
})
|
||||
}()
|
||||
go func() {
|
||||
errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) {
|
||||
l.Log(id, n, 0)
|
||||
errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) bool {
|
||||
return l.Log(id, n, 0)
|
||||
})
|
||||
}()
|
||||
// Block until one of the two goroutines returns
|
||||
|
|
|
@ -279,6 +279,10 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
// Cleanup
|
||||
_ = tConn.Close()
|
||||
_ = stream.Close()
|
||||
// Disconnect the client if TrafficLogger requested
|
||||
if err == errDisconnect {
|
||||
_ = h.conn.CloseWithError(0, "")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
|
||||
|
@ -316,7 +320,12 @@ func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
|
|||
udpN, rAddr, err := conn.ReadFrom(udpBuf)
|
||||
if udpN > 0 {
|
||||
if h.config.TrafficLogger != nil {
|
||||
h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN))
|
||||
ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN))
|
||||
if !ok {
|
||||
// TrafficLogger requested to disconnect the client
|
||||
_ = h.conn.CloseWithError(0, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
// Try no frag first
|
||||
msg := protocol.UDPMessage{
|
||||
|
@ -371,20 +380,30 @@ func (h *h3sHandler) udpLoop() {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.handleUDPMessage(msg)
|
||||
ok := h.handleUDPMessage(msg)
|
||||
if !ok {
|
||||
// TrafficLogger requested to disconnect the client
|
||||
_ = h.conn.CloseWithError(0, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// client -> remote direction
|
||||
func (h *h3sHandler) handleUDPMessage(msg []byte) {
|
||||
// Returns a bool indicating whether the receiving loop should continue
|
||||
func (h *h3sHandler) handleUDPMessage(msg []byte) (ok bool) {
|
||||
udpMsg, err := protocol.ParseUDPMessage(msg)
|
||||
if err != nil {
|
||||
return
|
||||
return true
|
||||
}
|
||||
n, _ := h.udpSM.Feed(udpMsg)
|
||||
if n > 0 && h.config.TrafficLogger != nil {
|
||||
h.config.TrafficLogger.Log(h.authID, uint64(n), 0)
|
||||
if h.config.TrafficLogger != nil {
|
||||
ok := h.config.TrafficLogger.Log(h.authID, uint64(len(udpMsg.Data)), 0)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
_, _ = h.udpSM.Feed(udpMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue