feat: add a Check method to let the implementation decide whether to hook a request

This commit is contained in:
Toby 2024-06-18 21:46:25 -07:00
parent 2c62a1a1b4
commit c78dbb38a1
4 changed files with 66 additions and 10 deletions

View file

@ -23,6 +23,7 @@ func TestClientServerHookTCP(t *testing.T) {
auth := mocks.NewMockAuthenticator(t) auth := mocks.NewMockAuthenticator(t)
auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody")
hook := mocks.NewMockRequestHook(t) hook := mocks.NewMockRequestHook(t)
hook.EXPECT().Check(false, fakeEchoAddr).Return(true).Once()
hook.EXPECT().TCP(mock.Anything, mock.Anything).RunAndReturn(func(stream quic.Stream, s *string) ([]byte, error) { hook.EXPECT().TCP(mock.Anything, mock.Anything).RunAndReturn(func(stream quic.Stream, s *string) ([]byte, error) {
assert.Equal(t, fakeEchoAddr, *s) assert.Equal(t, fakeEchoAddr, *s)
// Change the address // Change the address
@ -86,6 +87,7 @@ func TestClientServerHookUDP(t *testing.T) {
auth := mocks.NewMockAuthenticator(t) auth := mocks.NewMockAuthenticator(t)
auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody")
hook := mocks.NewMockRequestHook(t) hook := mocks.NewMockRequestHook(t)
hook.EXPECT().Check(true, fakeEchoAddr).Return(true).Once()
hook.EXPECT().UDP(mock.Anything, mock.Anything).RunAndReturn(func(bytes []byte, s *string) error { hook.EXPECT().UDP(mock.Anything, mock.Anything).RunAndReturn(func(bytes []byte, s *string) error {
assert.Equal(t, fakeEchoAddr, *s) assert.Equal(t, fakeEchoAddr, *s)
assert.Equal(t, []byte("hello world"), bytes) assert.Equal(t, []byte("hello world"), bytes)

View file

@ -20,6 +20,53 @@ func (_m *MockRequestHook) EXPECT() *MockRequestHook_Expecter {
return &MockRequestHook_Expecter{mock: &_m.Mock} return &MockRequestHook_Expecter{mock: &_m.Mock}
} }
// Check provides a mock function with given fields: isUDP, reqAddr
func (_m *MockRequestHook) Check(isUDP bool, reqAddr string) bool {
ret := _m.Called(isUDP, reqAddr)
if len(ret) == 0 {
panic("no return value specified for Check")
}
var r0 bool
if rf, ok := ret.Get(0).(func(bool, string) bool); ok {
r0 = rf(isUDP, reqAddr)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockRequestHook_Check_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Check'
type MockRequestHook_Check_Call struct {
*mock.Call
}
// Check is a helper method to define mock.On call
// - isUDP bool
// - reqAddr string
func (_e *MockRequestHook_Expecter) Check(isUDP interface{}, reqAddr interface{}) *MockRequestHook_Check_Call {
return &MockRequestHook_Check_Call{Call: _e.mock.On("Check", isUDP, reqAddr)}
}
func (_c *MockRequestHook_Check_Call) Run(run func(isUDP bool, reqAddr string)) *MockRequestHook_Check_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(bool), args[1].(string))
})
return _c
}
func (_c *MockRequestHook_Check_Call) Return(_a0 bool) *MockRequestHook_Check_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockRequestHook_Check_Call) RunAndReturn(run func(bool, string) bool) *MockRequestHook_Check_Call {
_c.Call.Return(run)
return _c
}
// TCP provides a mock function with given fields: stream, reqAddr // TCP provides a mock function with given fields: stream, reqAddr
func (_m *MockRequestHook) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) { func (_m *MockRequestHook) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) {
ret := _m.Called(stream, reqAddr) ret := _m.Called(stream, reqAddr)

View file

@ -113,12 +113,14 @@ type QUICConfig struct {
} }
// RequestHook allows filtering and modifying requests before the server connects to the remote. // RequestHook allows filtering and modifying requests before the server connects to the remote.
// A request will only be hooked if Check returns true.
// The returned byte slice, if not empty, will be sent to the remote before proxying - this is // The returned byte slice, if not empty, will be sent to the remote before proxying - this is
// mainly for "putting back" the content read from the client for sniffing, etc. // mainly for "putting back" the content read from the client for sniffing, etc.
// Return a non-nil error to abort the connection. // Return a non-nil error to abort the connection.
// Note that due to the current architectural limitations, it can only inspect the first packet // Note that due to the current architectural limitations, it can only inspect the first packet
// of a UDP connection. It also cannot put back any data as the first packet is always sent as-is. // of a UDP connection. It also cannot put back any data as the first packet is always sent as-is.
type RequestHook interface { type RequestHook interface {
Check(isUDP bool, reqAddr string) bool
TCP(stream quic.Stream, reqAddr *string) ([]byte, error) TCP(stream quic.Stream, reqAddr *string) ([]byte, error)
UDP(data []byte, reqAddr *string) error UDP(data []byte, reqAddr *string) error
} }

View file

@ -213,15 +213,19 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
} }
// Call the hook if set // Call the hook if set
var putback []byte var putback []byte
var hooked bool
if h.config.RequestHook != nil { if h.config.RequestHook != nil {
// When RequestHook is enabled, the server should always accept a connection hooked = h.config.RequestHook.Check(false, reqAddr)
// When the hook is enabled, the server should always accept a connection
// so that the client will send whatever request the hook wants to see. // so that the client will send whatever request the hook wants to see.
// This is essentially a server-side fast-open. // This is essentially a server-side fast-open.
_ = protocol.WriteTCPResponse(stream, true, "RequestHook enabled") if hooked {
putback, err = h.config.RequestHook.TCP(stream, &reqAddr) _ = protocol.WriteTCPResponse(stream, true, "RequestHook enabled")
if err != nil { putback, err = h.config.RequestHook.TCP(stream, &reqAddr)
_ = stream.Close() if err != nil {
return _ = stream.Close()
return
}
} }
} }
// Log the event // Log the event
@ -231,7 +235,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
// Dial target // Dial target
tConn, err := h.config.Outbound.TCP(reqAddr) tConn, err := h.config.Outbound.TCP(reqAddr)
if err != nil { if err != nil {
if h.config.RequestHook == nil { if !hooked {
_ = protocol.WriteTCPResponse(stream, false, err.Error()) _ = protocol.WriteTCPResponse(stream, false, err.Error())
} }
_ = stream.Close() _ = stream.Close()
@ -241,7 +245,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
} }
return return
} }
if h.config.RequestHook == nil { if !hooked {
_ = protocol.WriteTCPResponse(stream, true, "Connected") _ = protocol.WriteTCPResponse(stream, true, "Connected")
} }
// Put back the data if the hook requested // Put back the data if the hook requested
@ -327,10 +331,11 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
} }
func (io *udpIOImpl) Hook(data []byte, reqAddr *string) error { func (io *udpIOImpl) Hook(data []byte, reqAddr *string) error {
if io.RequestHook != nil { if io.RequestHook != nil && io.RequestHook.Check(true, *reqAddr) {
return io.RequestHook.UDP(data, reqAddr) return io.RequestHook.UDP(data, reqAddr)
} else {
return nil
} }
return nil
} }
func (io *udpIOImpl) UDP(reqAddr string) (UDPConn, error) { func (io *udpIOImpl) UDP(reqAddr string) (UDPConn, error) {