mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
feat(core): server RequestHook support
This commit is contained in:
parent
4c2a905892
commit
feacb1f85e
13 changed files with 416 additions and 19 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/apernet/hysteria/core/v2/errors"
|
||||
"github.com/apernet/hysteria/core/v2/internal/pmtud"
|
||||
"github.com/apernet/quic-go"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -22,6 +23,7 @@ type Config struct {
|
|||
TLSConfig TLSConfig
|
||||
QUICConfig QUICConfig
|
||||
Conn net.PacketConn
|
||||
RequestHook RequestHook
|
||||
Outbound Outbound
|
||||
BandwidthConfig BandwidthConfig
|
||||
IgnoreClientBandwidth bool
|
||||
|
@ -110,6 +112,17 @@ type QUICConfig struct {
|
|||
DisablePathMTUDiscovery bool // The server may still override this to true on unsupported platforms.
|
||||
}
|
||||
|
||||
// RequestHook allows filtering and modifying requests before the server connects to the remote.
|
||||
// The returned byte slice, if not empty, will be sent to the remote before proxying - this is
|
||||
// mainly for "putting back" the content read from the client for sniffing, etc.
|
||||
// Return a non-nil error to abort the connection.
|
||||
// Note that due to the current architectural limitations, it can only inspect the first packet
|
||||
// of a UDP connection. It also cannot put back any data as the first packet is always sent as-is.
|
||||
type RequestHook interface {
|
||||
TCP(stream quic.Stream, reqAddr *string) ([]byte, error)
|
||||
UDP(data []byte, reqAddr *string) error
|
||||
}
|
||||
|
||||
// Outbound provides the implementation of how the server should connect to remote servers.
|
||||
// Although UDP includes a reqAddr, the implementation does not necessarily have to use it
|
||||
// to make a "connected" UDP connection that does not accept packets from other addresses.
|
||||
|
|
|
@ -20,6 +20,53 @@ func (_m *mockUDPIO) EXPECT() *mockUDPIO_Expecter {
|
|||
return &mockUDPIO_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Hook provides a mock function with given fields: data, reqAddr
|
||||
func (_m *mockUDPIO) Hook(data []byte, reqAddr *string) error {
|
||||
ret := _m.Called(data, reqAddr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Hook")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, *string) error); ok {
|
||||
r0 = rf(data, reqAddr)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// mockUDPIO_Hook_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Hook'
|
||||
type mockUDPIO_Hook_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Hook is a helper method to define mock.On call
|
||||
// - data []byte
|
||||
// - reqAddr *string
|
||||
func (_e *mockUDPIO_Expecter) Hook(data interface{}, reqAddr interface{}) *mockUDPIO_Hook_Call {
|
||||
return &mockUDPIO_Hook_Call{Call: _e.mock.On("Hook", data, reqAddr)}
|
||||
}
|
||||
|
||||
func (_c *mockUDPIO_Hook_Call) Run(run func(data []byte, reqAddr *string)) *mockUDPIO_Hook_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].(*string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *mockUDPIO_Hook_Call) Return(_a0 error) *mockUDPIO_Hook_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *mockUDPIO_Hook_Call) RunAndReturn(run func([]byte, *string) error) *mockUDPIO_Hook_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ReceiveMessage provides a mock function with given fields:
|
||||
func (_m *mockUDPIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||
ret := _m.Called()
|
||||
|
|
|
@ -170,7 +170,7 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
if !h.config.DisableUDP {
|
||||
go func() {
|
||||
sm := newUDPSessionManager(
|
||||
&udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.Outbound},
|
||||
&udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.RequestHook, h.config.Outbound},
|
||||
&udpEventLoggerImpl{h.conn, id, h.config.EventLogger},
|
||||
h.config.UDPIdleTimeout)
|
||||
h.udpSM = sm
|
||||
|
@ -211,6 +211,16 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
// Call the hook if set
|
||||
var putback []byte
|
||||
if h.config.RequestHook != nil {
|
||||
putback, err = h.config.RequestHook.TCP(stream, &reqAddr)
|
||||
if err != nil {
|
||||
_ = protocol.WriteTCPResponse(stream, false, err.Error())
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
// Log the event
|
||||
if h.config.EventLogger != nil {
|
||||
h.config.EventLogger.TCPRequest(h.conn.RemoteAddr(), h.authID, reqAddr)
|
||||
|
@ -227,6 +237,10 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
return
|
||||
}
|
||||
_ = protocol.WriteTCPResponse(stream, true, "")
|
||||
// Put back the data if the hook requested
|
||||
if len(putback) > 0 {
|
||||
_, _ = tConn.Write(putback)
|
||||
}
|
||||
// Start proxying
|
||||
if h.config.TrafficLogger != nil {
|
||||
err = copyTwoWayWithLogger(h.authID, stream, tConn, h.config.TrafficLogger)
|
||||
|
@ -260,6 +274,7 @@ type udpIOImpl struct {
|
|||
Conn quic.Connection
|
||||
AuthID string
|
||||
TrafficLogger TrafficLogger
|
||||
RequestHook RequestHook
|
||||
Outbound Outbound
|
||||
}
|
||||
|
||||
|
@ -304,6 +319,13 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
|||
return io.Conn.SendDatagram(buf[:msgN])
|
||||
}
|
||||
|
||||
func (io *udpIOImpl) Hook(data []byte, reqAddr *string) error {
|
||||
if io.RequestHook != nil {
|
||||
return io.RequestHook.UDP(data, reqAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (io *udpIOImpl) UDP(reqAddr string) (UDPConn, error) {
|
||||
return io.Outbound.UDP(reqAddr)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ const (
|
|||
type udpIO interface {
|
||||
ReceiveMessage() (*protocol.UDPMessage, error)
|
||||
SendMessage([]byte, *protocol.UDPMessage) error
|
||||
Hook(data []byte, reqAddr *string) error
|
||||
UDP(reqAddr string) (UDPConn, error)
|
||||
}
|
||||
|
||||
|
@ -29,11 +30,12 @@ type udpEventLogger interface {
|
|||
}
|
||||
|
||||
type udpSessionEntry struct {
|
||||
ID uint32
|
||||
Conn UDPConn
|
||||
D *frag.Defragger
|
||||
Last *utils.AtomicTime
|
||||
Timeout bool // true if the session is closed due to timeout
|
||||
ID uint32
|
||||
Conn UDPConn
|
||||
OverrideAddr string // Ignore the address in the UDP message, always use this if not empty
|
||||
D *frag.Defragger
|
||||
Last *utils.AtomicTime
|
||||
Timeout bool // true if the session is closed due to timeout
|
||||
}
|
||||
|
||||
// Feed feeds a UDP message to the session.
|
||||
|
@ -47,7 +49,11 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
|
|||
if dfMsg == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
||||
if e.OverrideAddr != "" {
|
||||
return e.Conn.WriteTo(dfMsg.Data, e.OverrideAddr)
|
||||
} else {
|
||||
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
|
||||
|
@ -177,7 +183,15 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
|||
|
||||
// Create a new session if not exists
|
||||
if entry == nil {
|
||||
// Call the hook
|
||||
origMsgAddr := msg.Addr
|
||||
err := m.io.Hook(msg.Data, &msg.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Log the event
|
||||
m.eventLogger.New(msg.SessionID, msg.Addr)
|
||||
// Dial target & create a new session entry
|
||||
conn, err := m.io.UDP(msg.Addr)
|
||||
if err != nil {
|
||||
m.eventLogger.Close(msg.SessionID, err)
|
||||
|
@ -189,6 +203,10 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
|||
D: &frag.Defragger{},
|
||||
Last: utils.NewAtomicTime(time.Now()),
|
||||
}
|
||||
if origMsgAddr != msg.Addr {
|
||||
// Hook changed the address, enable address override
|
||||
entry.OverrideAddr = msg.Addr
|
||||
}
|
||||
// Start the receive loop for this session
|
||||
go func() {
|
||||
err := entry.ReceiveLoop(m.io)
|
||||
|
|
|
@ -25,6 +25,7 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
}
|
||||
return m, nil
|
||||
})
|
||||
io.EXPECT().Hook(mock.Anything, mock.Anything).Return(nil, nil)
|
||||
|
||||
go sm.Run()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue