feat(core): server RequestHook support

This commit is contained in:
Toby 2024-06-15 14:15:56 -07:00
parent 4c2a905892
commit feacb1f85e
13 changed files with 416 additions and 19 deletions

View file

@ -8,6 +8,7 @@ import (
"github.com/apernet/hysteria/core/v2/errors"
"github.com/apernet/hysteria/core/v2/internal/pmtud"
"github.com/apernet/quic-go"
)
const (
@ -22,6 +23,7 @@ type Config struct {
TLSConfig TLSConfig
QUICConfig QUICConfig
Conn net.PacketConn
RequestHook RequestHook
Outbound Outbound
BandwidthConfig BandwidthConfig
IgnoreClientBandwidth bool
@ -110,6 +112,17 @@ type QUICConfig struct {
DisablePathMTUDiscovery bool // The server may still override this to true on unsupported platforms.
}
// RequestHook allows filtering and modifying requests before the server connects to the remote.
// The returned byte slice, if not empty, will be sent to the remote before proxying - this is
// mainly for "putting back" the content read from the client for sniffing, etc.
// Return a non-nil error to abort the connection.
// Note that due to the current architectural limitations, it can only inspect the first packet
// of a UDP connection. It also cannot put back any data as the first packet is always sent as-is.
type RequestHook interface {
TCP(stream quic.Stream, reqAddr *string) ([]byte, error)
UDP(data []byte, reqAddr *string) error
}
// Outbound provides the implementation of how the server should connect to remote servers.
// Although UDP includes a reqAddr, the implementation does not necessarily have to use it
// to make a "connected" UDP connection that does not accept packets from other addresses.

View file

@ -20,6 +20,53 @@ func (_m *mockUDPIO) EXPECT() *mockUDPIO_Expecter {
return &mockUDPIO_Expecter{mock: &_m.Mock}
}
// Hook provides a mock function with given fields: data, reqAddr
func (_m *mockUDPIO) Hook(data []byte, reqAddr *string) error {
ret := _m.Called(data, reqAddr)
if len(ret) == 0 {
panic("no return value specified for Hook")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, *string) error); ok {
r0 = rf(data, reqAddr)
} else {
r0 = ret.Error(0)
}
return r0
}
// mockUDPIO_Hook_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Hook'
type mockUDPIO_Hook_Call struct {
*mock.Call
}
// Hook is a helper method to define mock.On call
// - data []byte
// - reqAddr *string
func (_e *mockUDPIO_Expecter) Hook(data interface{}, reqAddr interface{}) *mockUDPIO_Hook_Call {
return &mockUDPIO_Hook_Call{Call: _e.mock.On("Hook", data, reqAddr)}
}
func (_c *mockUDPIO_Hook_Call) Run(run func(data []byte, reqAddr *string)) *mockUDPIO_Hook_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte), args[1].(*string))
})
return _c
}
func (_c *mockUDPIO_Hook_Call) Return(_a0 error) *mockUDPIO_Hook_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *mockUDPIO_Hook_Call) RunAndReturn(run func([]byte, *string) error) *mockUDPIO_Hook_Call {
_c.Call.Return(run)
return _c
}
// ReceiveMessage provides a mock function with given fields:
func (_m *mockUDPIO) ReceiveMessage() (*protocol.UDPMessage, error) {
ret := _m.Called()

View file

@ -170,7 +170,7 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !h.config.DisableUDP {
go func() {
sm := newUDPSessionManager(
&udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.Outbound},
&udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.RequestHook, h.config.Outbound},
&udpEventLoggerImpl{h.conn, id, h.config.EventLogger},
h.config.UDPIdleTimeout)
h.udpSM = sm
@ -211,6 +211,16 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
_ = stream.Close()
return
}
// Call the hook if set
var putback []byte
if h.config.RequestHook != nil {
putback, err = h.config.RequestHook.TCP(stream, &reqAddr)
if err != nil {
_ = protocol.WriteTCPResponse(stream, false, err.Error())
_ = stream.Close()
return
}
}
// Log the event
if h.config.EventLogger != nil {
h.config.EventLogger.TCPRequest(h.conn.RemoteAddr(), h.authID, reqAddr)
@ -227,6 +237,10 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
return
}
_ = protocol.WriteTCPResponse(stream, true, "")
// Put back the data if the hook requested
if len(putback) > 0 {
_, _ = tConn.Write(putback)
}
// Start proxying
if h.config.TrafficLogger != nil {
err = copyTwoWayWithLogger(h.authID, stream, tConn, h.config.TrafficLogger)
@ -260,6 +274,7 @@ type udpIOImpl struct {
Conn quic.Connection
AuthID string
TrafficLogger TrafficLogger
RequestHook RequestHook
Outbound Outbound
}
@ -304,6 +319,13 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
return io.Conn.SendDatagram(buf[:msgN])
}
func (io *udpIOImpl) Hook(data []byte, reqAddr *string) error {
if io.RequestHook != nil {
return io.RequestHook.UDP(data, reqAddr)
}
return nil
}
func (io *udpIOImpl) UDP(reqAddr string) (UDPConn, error) {
return io.Outbound.UDP(reqAddr)
}

View file

@ -20,6 +20,7 @@ const (
type udpIO interface {
ReceiveMessage() (*protocol.UDPMessage, error)
SendMessage([]byte, *protocol.UDPMessage) error
Hook(data []byte, reqAddr *string) error
UDP(reqAddr string) (UDPConn, error)
}
@ -29,11 +30,12 @@ type udpEventLogger interface {
}
type udpSessionEntry struct {
ID uint32
Conn UDPConn
D *frag.Defragger
Last *utils.AtomicTime
Timeout bool // true if the session is closed due to timeout
ID uint32
Conn UDPConn
OverrideAddr string // Ignore the address in the UDP message, always use this if not empty
D *frag.Defragger
Last *utils.AtomicTime
Timeout bool // true if the session is closed due to timeout
}
// Feed feeds a UDP message to the session.
@ -47,7 +49,11 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
if dfMsg == nil {
return 0, nil
}
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
if e.OverrideAddr != "" {
return e.Conn.WriteTo(dfMsg.Data, e.OverrideAddr)
} else {
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
}
}
// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
@ -177,7 +183,15 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
// Create a new session if not exists
if entry == nil {
// Call the hook
origMsgAddr := msg.Addr
err := m.io.Hook(msg.Data, &msg.Addr)
if err != nil {
return
}
// Log the event
m.eventLogger.New(msg.SessionID, msg.Addr)
// Dial target & create a new session entry
conn, err := m.io.UDP(msg.Addr)
if err != nil {
m.eventLogger.Close(msg.SessionID, err)
@ -189,6 +203,10 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
D: &frag.Defragger{},
Last: utils.NewAtomicTime(time.Now()),
}
if origMsgAddr != msg.Addr {
// Hook changed the address, enable address override
entry.OverrideAddr = msg.Addr
}
// Start the receive loop for this session
go func() {
err := entry.ReceiveLoop(m.io)

View file

@ -25,6 +25,7 @@ func TestUDPSessionManager(t *testing.T) {
}
return m, nil
})
io.EXPECT().Hook(mock.Anything, mock.Anything).Return(nil, nil)
go sm.Run()