mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 04:27:39 +03:00
feat(core): server RequestHook support
This commit is contained in:
parent
4c2a905892
commit
feacb1f85e
13 changed files with 416 additions and 19 deletions
10
core/go.sum
10
core/go.sum
|
@ -12,12 +12,15 @@ github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbV
|
|||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
|
@ -26,15 +29,20 @@ github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
|||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
|
||||
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
|
@ -59,8 +67,10 @@ golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
|||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
@ -24,3 +24,6 @@ packages:
|
|||
TrafficLogger:
|
||||
config:
|
||||
mockname: MockTrafficLogger
|
||||
RequestHook:
|
||||
config:
|
||||
mockname: MockRequestHook
|
144
core/internal/integration_tests/hook_test.go
Normal file
144
core/internal/integration_tests/hook_test.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package integration_tests
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/apernet/hysteria/core/v2/client"
|
||||
"github.com/apernet/hysteria/core/v2/internal/integration_tests/mocks"
|
||||
"github.com/apernet/hysteria/core/v2/server"
|
||||
"github.com/apernet/quic-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestClientServerHookTCP(t *testing.T) {
|
||||
fakeEchoAddr := "hahanope:6666"
|
||||
realEchoAddr := "127.0.0.1:22333"
|
||||
|
||||
// Create server
|
||||
udpConn, udpAddr, err := serverConn()
|
||||
assert.NoError(t, err)
|
||||
auth := mocks.NewMockAuthenticator(t)
|
||||
auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody")
|
||||
hook := mocks.NewMockRequestHook(t)
|
||||
hook.EXPECT().TCP(mock.Anything, mock.Anything).RunAndReturn(func(stream quic.Stream, s *string) ([]byte, error) {
|
||||
assert.Equal(t, fakeEchoAddr, *s)
|
||||
// Change the address
|
||||
*s = realEchoAddr
|
||||
// Read the first 5 bytes and replace them with "byeee"
|
||||
data := make([]byte, 5)
|
||||
_, err := io.ReadFull(stream, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assert.Equal(t, []byte("hello"), data)
|
||||
return []byte("byeee"), nil
|
||||
}).Once()
|
||||
s, err := server.NewServer(&server.Config{
|
||||
TLSConfig: serverTLSConfig(),
|
||||
Conn: udpConn,
|
||||
RequestHook: hook,
|
||||
Authenticator: auth,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer s.Close()
|
||||
go s.Serve()
|
||||
|
||||
// Create TCP echo server
|
||||
echoListener, err := net.Listen("tcp", realEchoAddr)
|
||||
assert.NoError(t, err)
|
||||
echoServer := &tcpEchoServer{Listener: echoListener}
|
||||
defer echoServer.Close()
|
||||
go echoServer.Serve()
|
||||
|
||||
// Create client
|
||||
c, _, err := client.NewClient(&client.Config{
|
||||
ServerAddr: udpAddr,
|
||||
TLSConfig: client.TLSConfig{InsecureSkipVerify: true},
|
||||
FastOpen: true, // Client MUST have FastOpen for this
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer c.Close()
|
||||
|
||||
// Dial TCP
|
||||
conn, err := c.TCP(fakeEchoAddr)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send and receive data
|
||||
sData := []byte("hello world")
|
||||
_, err = conn.Write(sData)
|
||||
assert.NoError(t, err)
|
||||
rData := make([]byte, len(sData))
|
||||
_, err = io.ReadFull(conn, rData)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("byeee world"), rData)
|
||||
}
|
||||
|
||||
func TestClientServerHookUDP(t *testing.T) {
|
||||
fakeEchoAddr := "hahanope:6666"
|
||||
realEchoAddr := "127.0.0.1:22333"
|
||||
|
||||
// Create server
|
||||
udpConn, udpAddr, err := serverConn()
|
||||
assert.NoError(t, err)
|
||||
auth := mocks.NewMockAuthenticator(t)
|
||||
auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody")
|
||||
hook := mocks.NewMockRequestHook(t)
|
||||
hook.EXPECT().UDP(mock.Anything, mock.Anything).RunAndReturn(func(bytes []byte, s *string) error {
|
||||
assert.Equal(t, fakeEchoAddr, *s)
|
||||
assert.Equal(t, []byte("hello world"), bytes)
|
||||
// Change the address
|
||||
*s = realEchoAddr
|
||||
return nil
|
||||
}).Once()
|
||||
s, err := server.NewServer(&server.Config{
|
||||
TLSConfig: serverTLSConfig(),
|
||||
Conn: udpConn,
|
||||
RequestHook: hook,
|
||||
Authenticator: auth,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer s.Close()
|
||||
go s.Serve()
|
||||
|
||||
// Create UDP echo server
|
||||
echoConn, err := net.ListenPacket("udp", realEchoAddr)
|
||||
assert.NoError(t, err)
|
||||
echoServer := &udpEchoServer{Conn: echoConn}
|
||||
defer echoServer.Close()
|
||||
go echoServer.Serve()
|
||||
|
||||
// Create client
|
||||
c, _, err := client.NewClient(&client.Config{
|
||||
ServerAddr: udpAddr,
|
||||
TLSConfig: client.TLSConfig{InsecureSkipVerify: true},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer c.Close()
|
||||
|
||||
// Listen UDP
|
||||
conn, err := c.UDP()
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send and receive data
|
||||
sData := []byte("hello world")
|
||||
err = conn.Send(sData, fakeEchoAddr)
|
||||
assert.NoError(t, err)
|
||||
rData, rAddr, err := conn.Receive()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sData, rData)
|
||||
assert.Equal(t, realEchoAddr, rAddr)
|
||||
|
||||
// Subsequent packets should also be sent to the real echo server
|
||||
sData = []byte("never stop fighting")
|
||||
err = conn.Send(sData, fakeEchoAddr)
|
||||
assert.NoError(t, err)
|
||||
rData, rAddr, err = conn.Receive()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sData, rData)
|
||||
assert.Equal(t, realEchoAddr, rAddr)
|
||||
}
|
141
core/internal/integration_tests/mocks/mock_RequestHook.go
Normal file
141
core/internal/integration_tests/mocks/mock_RequestHook.go
Normal file
|
@ -0,0 +1,141 @@
|
|||
// Code generated by mockery v2.43.0. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
quic "github.com/apernet/quic-go"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockRequestHook is an autogenerated mock type for the RequestHook type
|
||||
type MockRequestHook struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockRequestHook_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockRequestHook) EXPECT() *MockRequestHook_Expecter {
|
||||
return &MockRequestHook_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// TCP provides a mock function with given fields: stream, reqAddr
|
||||
func (_m *MockRequestHook) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) {
|
||||
ret := _m.Called(stream, reqAddr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for TCP")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(quic.Stream, *string) ([]byte, error)); ok {
|
||||
return rf(stream, reqAddr)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(quic.Stream, *string) []byte); ok {
|
||||
r0 = rf(stream, reqAddr)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(quic.Stream, *string) error); ok {
|
||||
r1 = rf(stream, reqAddr)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockRequestHook_TCP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TCP'
|
||||
type MockRequestHook_TCP_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// TCP is a helper method to define mock.On call
|
||||
// - stream quic.Stream
|
||||
// - reqAddr *string
|
||||
func (_e *MockRequestHook_Expecter) TCP(stream interface{}, reqAddr interface{}) *MockRequestHook_TCP_Call {
|
||||
return &MockRequestHook_TCP_Call{Call: _e.mock.On("TCP", stream, reqAddr)}
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_TCP_Call) Run(run func(stream quic.Stream, reqAddr *string)) *MockRequestHook_TCP_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(quic.Stream), args[1].(*string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_TCP_Call) Return(_a0 []byte, _a1 error) *MockRequestHook_TCP_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_TCP_Call) RunAndReturn(run func(quic.Stream, *string) ([]byte, error)) *MockRequestHook_TCP_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UDP provides a mock function with given fields: data, reqAddr
|
||||
func (_m *MockRequestHook) UDP(data []byte, reqAddr *string) error {
|
||||
ret := _m.Called(data, reqAddr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UDP")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, *string) error); ok {
|
||||
r0 = rf(data, reqAddr)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockRequestHook_UDP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UDP'
|
||||
type MockRequestHook_UDP_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UDP is a helper method to define mock.On call
|
||||
// - data []byte
|
||||
// - reqAddr *string
|
||||
func (_e *MockRequestHook_Expecter) UDP(data interface{}, reqAddr interface{}) *MockRequestHook_UDP_Call {
|
||||
return &MockRequestHook_UDP_Call{Call: _e.mock.On("UDP", data, reqAddr)}
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_UDP_Call) Run(run func(data []byte, reqAddr *string)) *MockRequestHook_UDP_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].(*string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_UDP_Call) Return(_a0 error) *MockRequestHook_UDP_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockRequestHook_UDP_Call) RunAndReturn(run func([]byte, *string) error) *MockRequestHook_UDP_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockRequestHook creates a new instance of MockRequestHook. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockRequestHook(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockRequestHook {
|
||||
mock := &MockRequestHook{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/apernet/hysteria/core/v2/errors"
|
||||
"github.com/apernet/hysteria/core/v2/internal/pmtud"
|
||||
"github.com/apernet/quic-go"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -22,6 +23,7 @@ type Config struct {
|
|||
TLSConfig TLSConfig
|
||||
QUICConfig QUICConfig
|
||||
Conn net.PacketConn
|
||||
RequestHook RequestHook
|
||||
Outbound Outbound
|
||||
BandwidthConfig BandwidthConfig
|
||||
IgnoreClientBandwidth bool
|
||||
|
@ -110,6 +112,17 @@ type QUICConfig struct {
|
|||
DisablePathMTUDiscovery bool // The server may still override this to true on unsupported platforms.
|
||||
}
|
||||
|
||||
// RequestHook allows filtering and modifying requests before the server connects to the remote.
|
||||
// The returned byte slice, if not empty, will be sent to the remote before proxying - this is
|
||||
// mainly for "putting back" the content read from the client for sniffing, etc.
|
||||
// Return a non-nil error to abort the connection.
|
||||
// Note that due to the current architectural limitations, it can only inspect the first packet
|
||||
// of a UDP connection. It also cannot put back any data as the first packet is always sent as-is.
|
||||
type RequestHook interface {
|
||||
TCP(stream quic.Stream, reqAddr *string) ([]byte, error)
|
||||
UDP(data []byte, reqAddr *string) error
|
||||
}
|
||||
|
||||
// Outbound provides the implementation of how the server should connect to remote servers.
|
||||
// Although UDP includes a reqAddr, the implementation does not necessarily have to use it
|
||||
// to make a "connected" UDP connection that does not accept packets from other addresses.
|
||||
|
|
|
@ -20,6 +20,53 @@ func (_m *mockUDPIO) EXPECT() *mockUDPIO_Expecter {
|
|||
return &mockUDPIO_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Hook provides a mock function with given fields: data, reqAddr
|
||||
func (_m *mockUDPIO) Hook(data []byte, reqAddr *string) error {
|
||||
ret := _m.Called(data, reqAddr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Hook")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, *string) error); ok {
|
||||
r0 = rf(data, reqAddr)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// mockUDPIO_Hook_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Hook'
|
||||
type mockUDPIO_Hook_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Hook is a helper method to define mock.On call
|
||||
// - data []byte
|
||||
// - reqAddr *string
|
||||
func (_e *mockUDPIO_Expecter) Hook(data interface{}, reqAddr interface{}) *mockUDPIO_Hook_Call {
|
||||
return &mockUDPIO_Hook_Call{Call: _e.mock.On("Hook", data, reqAddr)}
|
||||
}
|
||||
|
||||
func (_c *mockUDPIO_Hook_Call) Run(run func(data []byte, reqAddr *string)) *mockUDPIO_Hook_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].(*string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *mockUDPIO_Hook_Call) Return(_a0 error) *mockUDPIO_Hook_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *mockUDPIO_Hook_Call) RunAndReturn(run func([]byte, *string) error) *mockUDPIO_Hook_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ReceiveMessage provides a mock function with given fields:
|
||||
func (_m *mockUDPIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||
ret := _m.Called()
|
||||
|
|
|
@ -170,7 +170,7 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
if !h.config.DisableUDP {
|
||||
go func() {
|
||||
sm := newUDPSessionManager(
|
||||
&udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.Outbound},
|
||||
&udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.RequestHook, h.config.Outbound},
|
||||
&udpEventLoggerImpl{h.conn, id, h.config.EventLogger},
|
||||
h.config.UDPIdleTimeout)
|
||||
h.udpSM = sm
|
||||
|
@ -211,6 +211,16 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
// Call the hook if set
|
||||
var putback []byte
|
||||
if h.config.RequestHook != nil {
|
||||
putback, err = h.config.RequestHook.TCP(stream, &reqAddr)
|
||||
if err != nil {
|
||||
_ = protocol.WriteTCPResponse(stream, false, err.Error())
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
// Log the event
|
||||
if h.config.EventLogger != nil {
|
||||
h.config.EventLogger.TCPRequest(h.conn.RemoteAddr(), h.authID, reqAddr)
|
||||
|
@ -227,6 +237,10 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
return
|
||||
}
|
||||
_ = protocol.WriteTCPResponse(stream, true, "")
|
||||
// Put back the data if the hook requested
|
||||
if len(putback) > 0 {
|
||||
_, _ = tConn.Write(putback)
|
||||
}
|
||||
// Start proxying
|
||||
if h.config.TrafficLogger != nil {
|
||||
err = copyTwoWayWithLogger(h.authID, stream, tConn, h.config.TrafficLogger)
|
||||
|
@ -260,6 +274,7 @@ type udpIOImpl struct {
|
|||
Conn quic.Connection
|
||||
AuthID string
|
||||
TrafficLogger TrafficLogger
|
||||
RequestHook RequestHook
|
||||
Outbound Outbound
|
||||
}
|
||||
|
||||
|
@ -304,6 +319,13 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
|||
return io.Conn.SendDatagram(buf[:msgN])
|
||||
}
|
||||
|
||||
func (io *udpIOImpl) Hook(data []byte, reqAddr *string) error {
|
||||
if io.RequestHook != nil {
|
||||
return io.RequestHook.UDP(data, reqAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (io *udpIOImpl) UDP(reqAddr string) (UDPConn, error) {
|
||||
return io.Outbound.UDP(reqAddr)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ const (
|
|||
type udpIO interface {
|
||||
ReceiveMessage() (*protocol.UDPMessage, error)
|
||||
SendMessage([]byte, *protocol.UDPMessage) error
|
||||
Hook(data []byte, reqAddr *string) error
|
||||
UDP(reqAddr string) (UDPConn, error)
|
||||
}
|
||||
|
||||
|
@ -29,11 +30,12 @@ type udpEventLogger interface {
|
|||
}
|
||||
|
||||
type udpSessionEntry struct {
|
||||
ID uint32
|
||||
Conn UDPConn
|
||||
D *frag.Defragger
|
||||
Last *utils.AtomicTime
|
||||
Timeout bool // true if the session is closed due to timeout
|
||||
ID uint32
|
||||
Conn UDPConn
|
||||
OverrideAddr string // Ignore the address in the UDP message, always use this if not empty
|
||||
D *frag.Defragger
|
||||
Last *utils.AtomicTime
|
||||
Timeout bool // true if the session is closed due to timeout
|
||||
}
|
||||
|
||||
// Feed feeds a UDP message to the session.
|
||||
|
@ -47,7 +49,11 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
|
|||
if dfMsg == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
||||
if e.OverrideAddr != "" {
|
||||
return e.Conn.WriteTo(dfMsg.Data, e.OverrideAddr)
|
||||
} else {
|
||||
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
|
||||
|
@ -177,7 +183,15 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
|||
|
||||
// Create a new session if not exists
|
||||
if entry == nil {
|
||||
// Call the hook
|
||||
origMsgAddr := msg.Addr
|
||||
err := m.io.Hook(msg.Data, &msg.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Log the event
|
||||
m.eventLogger.New(msg.SessionID, msg.Addr)
|
||||
// Dial target & create a new session entry
|
||||
conn, err := m.io.UDP(msg.Addr)
|
||||
if err != nil {
|
||||
m.eventLogger.Close(msg.SessionID, err)
|
||||
|
@ -189,6 +203,10 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
|||
D: &frag.Defragger{},
|
||||
Last: utils.NewAtomicTime(time.Now()),
|
||||
}
|
||||
if origMsgAddr != msg.Addr {
|
||||
// Hook changed the address, enable address override
|
||||
entry.OverrideAddr = msg.Addr
|
||||
}
|
||||
// Start the receive loop for this session
|
||||
go func() {
|
||||
err := entry.ReceiveLoop(m.io)
|
||||
|
|
|
@ -25,6 +25,7 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
}
|
||||
return m, nil
|
||||
})
|
||||
io.EXPECT().Hook(mock.Anything, mock.Anything).Return(nil, nil)
|
||||
|
||||
go sm.Run()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue