feat: traffic logger with disconnect

This commit is contained in:
tobyxdd 2023-06-08 20:55:08 -07:00
parent 5b54edd09a
commit 4334d8afb8
4 changed files with 66 additions and 18 deletions

View file

@ -12,24 +12,31 @@ import (
type testTrafficLogger struct {
Tx, Rx uint64
Block atomic.Bool
}
func (l *testTrafficLogger) Log(id string, tx, rx uint64) bool {
atomic.AddUint64(&l.Tx, tx)
atomic.AddUint64(&l.Rx, rx)
return true
return !l.Block.Load()
}
func (l *testTrafficLogger) Get() (tx, rx uint64) {
return atomic.LoadUint64(&l.Tx), atomic.LoadUint64(&l.Rx)
}
func (l *testTrafficLogger) SetBlock(block bool) {
l.Block.Store(block)
}
func (l *testTrafficLogger) Reset() {
atomic.StoreUint64(&l.Tx, 0)
atomic.StoreUint64(&l.Rx, 0)
}
// TestServerTrafficLogger tests that the server's TrafficLogger interface is working correctly.
// More specifically, it tests that the server is correctly logging traffic in both directions,
// and that it is correctly disconnecting clients when the traffic logger returns false.
func TestServerTrafficLogger(t *testing.T) {
tl := &testTrafficLogger{}
@ -147,4 +154,20 @@ func TestServerTrafficLogger(t *testing.T) {
if tx != uint64(len(sData)) || rx != uint64(len(sData)*2) {
t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(sData)*2, tx, rx)
}
// Check the disconnect client functionality
tl.SetBlock(true)
// Send and receive TCP data again
sData = []byte("1234")
_, err = tConn.Write(sData)
if err != nil {
t.Fatal("error writing to TCP:", err)
}
// This should fail instantly without reading any data
// io.Copy should return nil as EOF is treated as a non-error though
n, err := io.Copy(io.Discard, tConn)
if n != 0 || err != nil {
t.Fatal("expected 0 bytes read and nil error, got", n, err)
}
}

View file

@ -180,5 +180,5 @@ type EventLogger interface {
// bandwidth limits or post-connection authentication, for example.
// The implementation of this interface must be thread-safe.
type TrafficLogger interface {
Log(id string, tx, rx uint64) bool
Log(id string, tx, rx uint64) (ok bool)
}

View file

@ -1,16 +1,22 @@
package server
import "io"
import (
"errors"
"io"
)
func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error {
var errDisconnect = errors.New("traffic logger requested disconnect")
func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64) bool) error {
buf := make([]byte, 32*1024)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
log(uint64(nw))
if !log(uint64(nr)) {
// Log returns false, which means that the client should be disconnected
return errDisconnect
}
_, ew := dst.Write(buf[0:nr])
if ew != nil {
return ew
}
@ -28,13 +34,13 @@ func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error {
func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger) error {
errChan := make(chan error, 2)
go func() {
errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) {
l.Log(id, 0, n)
errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) bool {
return l.Log(id, 0, n)
})
}()
go func() {
errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) {
l.Log(id, n, 0)
errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) bool {
return l.Log(id, n, 0)
})
}()
// Block until one of the two goroutines returns

View file

@ -279,6 +279,10 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
// Cleanup
_ = tConn.Close()
_ = stream.Close()
// Disconnect the client if TrafficLogger requested
if err == errDisconnect {
_ = h.conn.CloseWithError(0, "")
}
}
func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
@ -316,7 +320,12 @@ func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
udpN, rAddr, err := conn.ReadFrom(udpBuf)
if udpN > 0 {
if h.config.TrafficLogger != nil {
h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN))
ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN))
if !ok {
// TrafficLogger requested to disconnect the client
_ = h.conn.CloseWithError(0, "")
return
}
}
// Try no frag first
msg := protocol.UDPMessage{
@ -371,20 +380,30 @@ func (h *h3sHandler) udpLoop() {
if err != nil {
return
}
h.handleUDPMessage(msg)
ok := h.handleUDPMessage(msg)
if !ok {
// TrafficLogger requested to disconnect the client
_ = h.conn.CloseWithError(0, "")
return
}
}
}
// client -> remote direction
func (h *h3sHandler) handleUDPMessage(msg []byte) {
// Returns a bool indicating whether the receiving loop should continue
func (h *h3sHandler) handleUDPMessage(msg []byte) (ok bool) {
udpMsg, err := protocol.ParseUDPMessage(msg)
if err != nil {
return
return true
}
n, _ := h.udpSM.Feed(udpMsg)
if n > 0 && h.config.TrafficLogger != nil {
h.config.TrafficLogger.Log(h.authID, uint64(n), 0)
if h.config.TrafficLogger != nil {
ok := h.config.TrafficLogger.Log(h.authID, uint64(len(udpMsg.Data)), 0)
if !ok {
return false
}
}
_, _ = h.udpSM.Feed(udpMsg)
return true
}
func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {