mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
feat: add a Check method to let the implementation decide whether to hook a request
This commit is contained in:
parent
2c62a1a1b4
commit
c78dbb38a1
4 changed files with 66 additions and 10 deletions
|
@ -23,6 +23,7 @@ func TestClientServerHookTCP(t *testing.T) {
|
|||
auth := mocks.NewMockAuthenticator(t)
|
||||
auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody")
|
||||
hook := mocks.NewMockRequestHook(t)
|
||||
hook.EXPECT().Check(false, fakeEchoAddr).Return(true).Once()
|
||||
hook.EXPECT().TCP(mock.Anything, mock.Anything).RunAndReturn(func(stream quic.Stream, s *string) ([]byte, error) {
|
||||
assert.Equal(t, fakeEchoAddr, *s)
|
||||
// Change the address
|
||||
|
@ -86,6 +87,7 @@ func TestClientServerHookUDP(t *testing.T) {
|
|||
auth := mocks.NewMockAuthenticator(t)
|
||||
auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody")
|
||||
hook := mocks.NewMockRequestHook(t)
|
||||
hook.EXPECT().Check(true, fakeEchoAddr).Return(true).Once()
|
||||
hook.EXPECT().UDP(mock.Anything, mock.Anything).RunAndReturn(func(bytes []byte, s *string) error {
|
||||
assert.Equal(t, fakeEchoAddr, *s)
|
||||
assert.Equal(t, []byte("hello world"), bytes)
|
||||
|
|
|
@ -20,6 +20,53 @@ func (_m *MockRequestHook) EXPECT() *MockRequestHook_Expecter {
|
|||
return &MockRequestHook_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Check provides a mock function with given fields: isUDP, reqAddr
|
||||
func (_m *MockRequestHook) Check(isUDP bool, reqAddr string) bool {
|
||||
ret := _m.Called(isUDP, reqAddr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Check")
|
||||
}
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func(bool, string) bool); ok {
|
||||
r0 = rf(isUDP, reqAddr)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockRequestHook_Check_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Check'
|
||||
type MockRequestHook_Check_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Check is a helper method to define mock.On call
|
||||
// - isUDP bool
|
||||
// - reqAddr string
|
||||
func (_e *MockRequestHook_Expecter) Check(isUDP interface{}, reqAddr interface{}) *MockRequestHook_Check_Call {
|
||||
return &MockRequestHook_Check_Call{Call: _e.mock.On("Check", isUDP, reqAddr)}
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_Check_Call) Run(run func(isUDP bool, reqAddr string)) *MockRequestHook_Check_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(bool), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_Check_Call) Return(_a0 bool) *MockRequestHook_Check_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_Check_Call) RunAndReturn(run func(bool, string) bool) *MockRequestHook_Check_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// TCP provides a mock function with given fields: stream, reqAddr
|
||||
func (_m *MockRequestHook) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) {
|
||||
ret := _m.Called(stream, reqAddr)
|
||||
|
|
|
@ -113,12 +113,14 @@ type QUICConfig struct {
|
|||
}
|
||||
|
||||
// RequestHook allows filtering and modifying requests before the server connects to the remote.
|
||||
// A request will only be hooked if Check returns true.
|
||||
// The returned byte slice, if not empty, will be sent to the remote before proxying - this is
|
||||
// mainly for "putting back" the content read from the client for sniffing, etc.
|
||||
// Return a non-nil error to abort the connection.
|
||||
// Note that due to the current architectural limitations, it can only inspect the first packet
|
||||
// of a UDP connection. It also cannot put back any data as the first packet is always sent as-is.
|
||||
type RequestHook interface {
|
||||
Check(isUDP bool, reqAddr string) bool
|
||||
TCP(stream quic.Stream, reqAddr *string) ([]byte, error)
|
||||
UDP(data []byte, reqAddr *string) error
|
||||
}
|
||||
|
|
|
@ -213,15 +213,19 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
}
|
||||
// Call the hook if set
|
||||
var putback []byte
|
||||
var hooked bool
|
||||
if h.config.RequestHook != nil {
|
||||
// When RequestHook is enabled, the server should always accept a connection
|
||||
hooked = h.config.RequestHook.Check(false, reqAddr)
|
||||
// When the hook is enabled, the server should always accept a connection
|
||||
// so that the client will send whatever request the hook wants to see.
|
||||
// This is essentially a server-side fast-open.
|
||||
_ = protocol.WriteTCPResponse(stream, true, "RequestHook enabled")
|
||||
putback, err = h.config.RequestHook.TCP(stream, &reqAddr)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return
|
||||
if hooked {
|
||||
_ = protocol.WriteTCPResponse(stream, true, "RequestHook enabled")
|
||||
putback, err = h.config.RequestHook.TCP(stream, &reqAddr)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Log the event
|
||||
|
@ -231,7 +235,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
// Dial target
|
||||
tConn, err := h.config.Outbound.TCP(reqAddr)
|
||||
if err != nil {
|
||||
if h.config.RequestHook == nil {
|
||||
if !hooked {
|
||||
_ = protocol.WriteTCPResponse(stream, false, err.Error())
|
||||
}
|
||||
_ = stream.Close()
|
||||
|
@ -241,7 +245,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
}
|
||||
return
|
||||
}
|
||||
if h.config.RequestHook == nil {
|
||||
if !hooked {
|
||||
_ = protocol.WriteTCPResponse(stream, true, "Connected")
|
||||
}
|
||||
// Put back the data if the hook requested
|
||||
|
@ -327,10 +331,11 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
|||
}
|
||||
|
||||
func (io *udpIOImpl) Hook(data []byte, reqAddr *string) error {
|
||||
if io.RequestHook != nil {
|
||||
if io.RequestHook != nil && io.RequestHook.Check(true, *reqAddr) {
|
||||
return io.RequestHook.UDP(data, reqAddr)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (io *udpIOImpl) UDP(reqAddr string) (UDPConn, error) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue