From c78dbb38a14fc228b9958b6527cbf583b07b061d Mon Sep 17 00:00:00 2001 From: Toby Date: Tue, 18 Jun 2024 21:46:25 -0700 Subject: [PATCH] feat: add a Check method to let the implementation decide whether to hook a request --- core/internal/integration_tests/hook_test.go | 2 + .../mocks/mock_RequestHook.go | 47 +++++++++++++++++++ core/server/config.go | 2 + core/server/server.go | 25 ++++++---- 4 files changed, 66 insertions(+), 10 deletions(-) diff --git a/core/internal/integration_tests/hook_test.go b/core/internal/integration_tests/hook_test.go index 43baf4f..1121d13 100644 --- a/core/internal/integration_tests/hook_test.go +++ b/core/internal/integration_tests/hook_test.go @@ -23,6 +23,7 @@ func TestClientServerHookTCP(t *testing.T) { auth := mocks.NewMockAuthenticator(t) auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") hook := mocks.NewMockRequestHook(t) + hook.EXPECT().Check(false, fakeEchoAddr).Return(true).Once() hook.EXPECT().TCP(mock.Anything, mock.Anything).RunAndReturn(func(stream quic.Stream, s *string) ([]byte, error) { assert.Equal(t, fakeEchoAddr, *s) // Change the address @@ -86,6 +87,7 @@ func TestClientServerHookUDP(t *testing.T) { auth := mocks.NewMockAuthenticator(t) auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") hook := mocks.NewMockRequestHook(t) + hook.EXPECT().Check(true, fakeEchoAddr).Return(true).Once() hook.EXPECT().UDP(mock.Anything, mock.Anything).RunAndReturn(func(bytes []byte, s *string) error { assert.Equal(t, fakeEchoAddr, *s) assert.Equal(t, []byte("hello world"), bytes) diff --git a/core/internal/integration_tests/mocks/mock_RequestHook.go b/core/internal/integration_tests/mocks/mock_RequestHook.go index a7a6cc5..5418eaf 100644 --- a/core/internal/integration_tests/mocks/mock_RequestHook.go +++ b/core/internal/integration_tests/mocks/mock_RequestHook.go @@ -20,6 +20,53 @@ func (_m *MockRequestHook) EXPECT() *MockRequestHook_Expecter { return &MockRequestHook_Expecter{mock: &_m.Mock} } +// Check provides a mock function with given fields: isUDP, reqAddr +func (_m *MockRequestHook) Check(isUDP bool, reqAddr string) bool { + ret := _m.Called(isUDP, reqAddr) + + if len(ret) == 0 { + panic("no return value specified for Check") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(bool, string) bool); ok { + r0 = rf(isUDP, reqAddr) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockRequestHook_Check_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Check' +type MockRequestHook_Check_Call struct { + *mock.Call +} + +// Check is a helper method to define mock.On call +// - isUDP bool +// - reqAddr string +func (_e *MockRequestHook_Expecter) Check(isUDP interface{}, reqAddr interface{}) *MockRequestHook_Check_Call { + return &MockRequestHook_Check_Call{Call: _e.mock.On("Check", isUDP, reqAddr)} +} + +func (_c *MockRequestHook_Check_Call) Run(run func(isUDP bool, reqAddr string)) *MockRequestHook_Check_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(bool), args[1].(string)) + }) + return _c +} + +func (_c *MockRequestHook_Check_Call) Return(_a0 bool) *MockRequestHook_Check_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRequestHook_Check_Call) RunAndReturn(run func(bool, string) bool) *MockRequestHook_Check_Call { + _c.Call.Return(run) + return _c +} + // TCP provides a mock function with given fields: stream, reqAddr func (_m *MockRequestHook) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) { ret := _m.Called(stream, reqAddr) diff --git a/core/server/config.go b/core/server/config.go index bd8b7ca..f90c820 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -113,12 +113,14 @@ type QUICConfig struct { } // RequestHook allows filtering and modifying requests before the server connects to the remote. +// A request will only be hooked if Check returns true. // The returned byte slice, if not empty, will be sent to the remote before proxying - this is // mainly for "putting back" the content read from the client for sniffing, etc. // Return a non-nil error to abort the connection. // Note that due to the current architectural limitations, it can only inspect the first packet // of a UDP connection. It also cannot put back any data as the first packet is always sent as-is. type RequestHook interface { + Check(isUDP bool, reqAddr string) bool TCP(stream quic.Stream, reqAddr *string) ([]byte, error) UDP(data []byte, reqAddr *string) error } diff --git a/core/server/server.go b/core/server/server.go index 4fc9d41..ba55b31 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -213,15 +213,19 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { } // Call the hook if set var putback []byte + var hooked bool if h.config.RequestHook != nil { - // When RequestHook is enabled, the server should always accept a connection + hooked = h.config.RequestHook.Check(false, reqAddr) + // When the hook is enabled, the server should always accept a connection // so that the client will send whatever request the hook wants to see. // This is essentially a server-side fast-open. - _ = protocol.WriteTCPResponse(stream, true, "RequestHook enabled") - putback, err = h.config.RequestHook.TCP(stream, &reqAddr) - if err != nil { - _ = stream.Close() - return + if hooked { + _ = protocol.WriteTCPResponse(stream, true, "RequestHook enabled") + putback, err = h.config.RequestHook.TCP(stream, &reqAddr) + if err != nil { + _ = stream.Close() + return + } } } // Log the event @@ -231,7 +235,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { // Dial target tConn, err := h.config.Outbound.TCP(reqAddr) if err != nil { - if h.config.RequestHook == nil { + if !hooked { _ = protocol.WriteTCPResponse(stream, false, err.Error()) } _ = stream.Close() @@ -241,7 +245,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { } return } - if h.config.RequestHook == nil { + if !hooked { _ = protocol.WriteTCPResponse(stream, true, "Connected") } // Put back the data if the hook requested @@ -327,10 +331,11 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error { } func (io *udpIOImpl) Hook(data []byte, reqAddr *string) error { - if io.RequestHook != nil { + if io.RequestHook != nil && io.RequestHook.Check(true, *reqAddr) { return io.RequestHook.UDP(data, reqAddr) + } else { + return nil } - return nil } func (io *udpIOImpl) UDP(reqAddr string) (UDPConn, error) {