feat: traffic logger with disconnect

This commit is contained in:
tobyxdd 2023-06-08 20:55:08 -07:00
parent 5b54edd09a
commit 4334d8afb8
4 changed files with 66 additions and 18 deletions

View file

@ -12,24 +12,31 @@ import (
type testTrafficLogger struct { type testTrafficLogger struct {
Tx, Rx uint64 Tx, Rx uint64
Block atomic.Bool
} }
func (l *testTrafficLogger) Log(id string, tx, rx uint64) bool { func (l *testTrafficLogger) Log(id string, tx, rx uint64) bool {
atomic.AddUint64(&l.Tx, tx) atomic.AddUint64(&l.Tx, tx)
atomic.AddUint64(&l.Rx, rx) atomic.AddUint64(&l.Rx, rx)
return true return !l.Block.Load()
} }
func (l *testTrafficLogger) Get() (tx, rx uint64) { func (l *testTrafficLogger) Get() (tx, rx uint64) {
return atomic.LoadUint64(&l.Tx), atomic.LoadUint64(&l.Rx) return atomic.LoadUint64(&l.Tx), atomic.LoadUint64(&l.Rx)
} }
func (l *testTrafficLogger) SetBlock(block bool) {
l.Block.Store(block)
}
func (l *testTrafficLogger) Reset() { func (l *testTrafficLogger) Reset() {
atomic.StoreUint64(&l.Tx, 0) atomic.StoreUint64(&l.Tx, 0)
atomic.StoreUint64(&l.Rx, 0) atomic.StoreUint64(&l.Rx, 0)
} }
// TestServerTrafficLogger tests that the server's TrafficLogger interface is working correctly. // TestServerTrafficLogger tests that the server's TrafficLogger interface is working correctly.
// More specifically, it tests that the server is correctly logging traffic in both directions,
// and that it is correctly disconnecting clients when the traffic logger returns false.
func TestServerTrafficLogger(t *testing.T) { func TestServerTrafficLogger(t *testing.T) {
tl := &testTrafficLogger{} tl := &testTrafficLogger{}
@ -147,4 +154,20 @@ func TestServerTrafficLogger(t *testing.T) {
if tx != uint64(len(sData)) || rx != uint64(len(sData)*2) { if tx != uint64(len(sData)) || rx != uint64(len(sData)*2) {
t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(sData)*2, tx, rx) t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(sData)*2, tx, rx)
} }
// Check the disconnect client functionality
tl.SetBlock(true)
// Send and receive TCP data again
sData = []byte("1234")
_, err = tConn.Write(sData)
if err != nil {
t.Fatal("error writing to TCP:", err)
}
// This should fail instantly without reading any data
// io.Copy should return nil as EOF is treated as a non-error though
n, err := io.Copy(io.Discard, tConn)
if n != 0 || err != nil {
t.Fatal("expected 0 bytes read and nil error, got", n, err)
}
} }

View file

@ -180,5 +180,5 @@ type EventLogger interface {
// bandwidth limits or post-connection authentication, for example. // bandwidth limits or post-connection authentication, for example.
// The implementation of this interface must be thread-safe. // The implementation of this interface must be thread-safe.
type TrafficLogger interface { type TrafficLogger interface {
Log(id string, tx, rx uint64) bool Log(id string, tx, rx uint64) (ok bool)
} }

View file

@ -1,16 +1,22 @@
package server package server
import "io" import (
"errors"
"io"
)
func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error { var errDisconnect = errors.New("traffic logger requested disconnect")
func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64) bool) error {
buf := make([]byte, 32*1024) buf := make([]byte, 32*1024)
for { for {
nr, er := src.Read(buf) nr, er := src.Read(buf)
if nr > 0 { if nr > 0 {
nw, ew := dst.Write(buf[0:nr]) if !log(uint64(nr)) {
if nw > 0 { // Log returns false, which means that the client should be disconnected
log(uint64(nw)) return errDisconnect
} }
_, ew := dst.Write(buf[0:nr])
if ew != nil { if ew != nil {
return ew return ew
} }
@ -28,13 +34,13 @@ func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error {
func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger) error { func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger) error {
errChan := make(chan error, 2) errChan := make(chan error, 2)
go func() { go func() {
errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) { errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) bool {
l.Log(id, 0, n) return l.Log(id, 0, n)
}) })
}() }()
go func() { go func() {
errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) { errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) bool {
l.Log(id, n, 0) return l.Log(id, n, 0)
}) })
}() }()
// Block until one of the two goroutines returns // Block until one of the two goroutines returns

View file

@ -279,6 +279,10 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
// Cleanup // Cleanup
_ = tConn.Close() _ = tConn.Close()
_ = stream.Close() _ = stream.Close()
// Disconnect the client if TrafficLogger requested
if err == errDisconnect {
_ = h.conn.CloseWithError(0, "")
}
} }
func (h *h3sHandler) handleUDPRequest(stream quic.Stream) { func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
@ -316,7 +320,12 @@ func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
udpN, rAddr, err := conn.ReadFrom(udpBuf) udpN, rAddr, err := conn.ReadFrom(udpBuf)
if udpN > 0 { if udpN > 0 {
if h.config.TrafficLogger != nil { if h.config.TrafficLogger != nil {
h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN)) ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN))
if !ok {
// TrafficLogger requested to disconnect the client
_ = h.conn.CloseWithError(0, "")
return
}
} }
// Try no frag first // Try no frag first
msg := protocol.UDPMessage{ msg := protocol.UDPMessage{
@ -371,20 +380,30 @@ func (h *h3sHandler) udpLoop() {
if err != nil { if err != nil {
return return
} }
h.handleUDPMessage(msg) ok := h.handleUDPMessage(msg)
if !ok {
// TrafficLogger requested to disconnect the client
_ = h.conn.CloseWithError(0, "")
return
}
} }
} }
// client -> remote direction // client -> remote direction
func (h *h3sHandler) handleUDPMessage(msg []byte) { // Returns a bool indicating whether the receiving loop should continue
func (h *h3sHandler) handleUDPMessage(msg []byte) (ok bool) {
udpMsg, err := protocol.ParseUDPMessage(msg) udpMsg, err := protocol.ParseUDPMessage(msg)
if err != nil { if err != nil {
return return true
} }
n, _ := h.udpSM.Feed(udpMsg) if h.config.TrafficLogger != nil {
if n > 0 && h.config.TrafficLogger != nil { ok := h.config.TrafficLogger.Log(h.authID, uint64(len(udpMsg.Data)), 0)
h.config.TrafficLogger.Log(h.authID, uint64(n), 0) if !ok {
return false
}
} }
_, _ = h.udpSM.Feed(udpMsg)
return true
} }
func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) { func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {