diff --git a/shadowaead_2022/protocol.go b/shadowaead_2022/protocol.go index d8e7e4e..c911e41 100644 --- a/shadowaead_2022/protocol.go +++ b/shadowaead_2022/protocol.go @@ -64,7 +64,7 @@ func init() { random.InitializeSeed() } -func NewWithPassword(method string, password string, options ...MethodOption) (shadowsocks.Method, error) { +func NewWithPassword(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) { var pskList [][]byte if password == "" { return nil, ErrMissingPSK @@ -78,12 +78,13 @@ func NewWithPassword(method string, password string, options ...MethodOption) (s } pskList[i] = kb } - return New(method, pskList, options...) + return New(method, pskList, timeFunc) } -func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.Method, error) { +func New(method string, pskList [][]byte, timeFunc func() time.Time) (shadowsocks.Method, error) { m := &Method{ - name: method, + name: method, + timeFunc: timeFunc, } switch method { @@ -146,9 +147,6 @@ func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks. } m.pskList = pskList - for _, option := range options { - option(m) - } return m, nil } @@ -177,8 +175,10 @@ func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block ci } type Method struct { - name string - keySaltLength int + name string + keySaltLength int + timeFunc func() time.Time + constructor func(key []byte) (cipher.AEAD, error) blockConstructor func(key []byte) (cipher.Block, error) udpCipher cipher.AEAD @@ -222,6 +222,14 @@ type clientConn struct { writer *shadowaead.Writer } +func (m *Method) time() time.Time { + if m.timeFunc != nil { + return m.timeFunc() + } else { + return time.Now() + } +} + func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error { pskLen := len(m.pskList) if pskLen < 2 { @@ -280,7 +288,7 @@ func (c *clientConn) writeRequest(payload []byte) error { var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:])) common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) - common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix()))) + common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.time().Unix()))) var paddingLen int if len(payload) < MaxPaddingLength { paddingLen = mRand.Intn(MaxPaddingLength) + 1 @@ -366,7 +374,7 @@ func (c *clientConn) readResponse() error { return err } - diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + diff := int(math.Abs(float64(c.time().Unix() - int64(epoch)))) if diff > 30 { return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") } @@ -526,7 +534,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad } common.Must( header.WriteByte(HeaderTypeClient), - binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(header, binary.BigEndian, uint64(c.time().Unix())), binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length ) @@ -632,7 +640,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { return M.Socksaddr{}, err } - diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + diff := int(math.Abs(float64(c.time().Unix() - int64(epoch)))) if diff > 30 { return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") } @@ -641,15 +649,15 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { c.session.window.Add(packetId) } else if sessionId == c.session.lastRemoteSessionId { c.session.lastWindow.Add(packetId) - c.session.lastRemoteSeen = time.Now().Unix() + c.session.lastRemoteSeen = c.time().Unix() } else { if c.session.remoteSessionId != 0 { - if time.Now().Unix()-c.session.lastRemoteSeen < 60 { + if c.time().Unix()-c.session.lastRemoteSeen < 60 { return M.Socksaddr{}, ErrTooManyServerSessions } else { c.session.lastRemoteSessionId = c.session.remoteSessionId c.session.lastWindow = c.session.window - c.session.lastRemoteSeen = time.Now().Unix() + c.session.lastRemoteSeen = c.time().Unix() c.session.lastRemoteCipher = c.session.remoteCipher c.session.window = SlidingWindow{} } @@ -758,7 +766,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } common.Must( buffer.WriteByte(HeaderTypeClient), - binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(buffer, binary.BigEndian, uint64(c.time().Unix())), binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length ) diff --git a/shadowaead_2022/protocol_option.go b/shadowaead_2022/protocol_option.go deleted file mode 100644 index c1dd77e..0000000 --- a/shadowaead_2022/protocol_option.go +++ /dev/null @@ -1,3 +0,0 @@ -package shadowaead_2022 - -type MethodOption func(*Method) diff --git a/shadowaead_2022/service.go b/shadowaead_2022/service.go index cb50d32..5c0ae55 100644 --- a/shadowaead_2022/service.go +++ b/shadowaead_2022/service.go @@ -42,6 +42,7 @@ type Service struct { name string keySaltLength int handler shadowsocks.Handler + timeFunc func() time.Time constructor func(key []byte) (cipher.AEAD, error) blockConstructor func(key []byte) (cipher.Block, error) @@ -54,7 +55,7 @@ type Service struct { udpSessions *cache.LruCache[uint64, *serverUDPSession] } -func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { +func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) { if password == "" { return nil, ErrMissingPSK } @@ -62,13 +63,14 @@ func NewServiceWithPassword(method string, password string, udpTimeout int64, ha if err != nil { return nil, E.Cause(err, "decode psk") } - return NewService(method, psk, udpTimeout, handler) + return NewService(method, psk, udpTimeout, handler, timeFunc) } -func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { +func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) { s := &Service{ - name: method, - handler: handler, + name: method, + handler: handler, + timeFunc: timeFunc, replayFilter: replay.NewSimple(60 * time.Second), udpNat: udpnat.New[uint64](udpTimeout, handler), @@ -135,6 +137,14 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M return err } +func (s *Service) time() time.Time { + if s.timeFunc != nil { + return s.timeFunc() + } else { + return time.Now() + } +} + func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength) @@ -183,7 +193,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M return err } - diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) if diff > 30 { return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") } @@ -280,7 +290,7 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { _headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2) headerFixedChunk := common.Dup(_headerFixedChunk) common.Must(headerFixedChunk.WriteByte(headerType)) - common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix()))) + common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix()))) common.Must1(headerFixedChunk.Write(c.requestSalt)) common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen))) @@ -469,7 +479,7 @@ process: if err != nil { goto returnErr } - diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) if diff > 30 { err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") goto returnErr @@ -539,7 +549,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks binary.Write(header, binary.BigEndian, w.session.sessionId), binary.Write(header, binary.BigEndian, w.session.nextPacketId()), header.WriteByte(HeaderTypeServer), - binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), + binary.Write(header, binary.BigEndian, uint64(w.time().Unix())), binary.Write(header, binary.BigEndian, w.session.remoteSessionId), binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length ) diff --git a/shadowaead_2022/service_multi.go b/shadowaead_2022/service_multi.go index 548be85..768d295 100644 --- a/shadowaead_2022/service_multi.go +++ b/shadowaead_2022/service_multi.go @@ -33,7 +33,7 @@ type MultiService[U comparable] struct { uCipher map[U]cipher.Block } -func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { +func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) { if password == "" { return nil, ErrMissingPSK } @@ -41,10 +41,10 @@ func NewMultiServiceWithPassword[U comparable](method string, password string, u if err != nil { return nil, E.Cause(err, "decode psk") } - return NewMultiService[U](method, iPSK, udpTimeout, handler) + return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc) } -func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { +func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) { switch method { case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-256-gcm": @@ -52,7 +52,7 @@ func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, return nil, os.ErrInvalid } - ss, err := NewService(method, iPSK, udpTimeout, handler) + ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc) if err != nil { return nil, err } @@ -192,7 +192,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta if err != nil { return E.Cause(err, "read timestamp") } - diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) if diff > 30 { return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") } @@ -342,7 +342,7 @@ process: if err != nil { goto returnErr } - diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) + diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) if diff > 30 { err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") goto returnErr diff --git a/shadowaead_2022/service_multi_test.go b/shadowaead_2022/service_multi_test.go index bd6974c..3848e13 100644 --- a/shadowaead_2022/service_multi_test.go +++ b/shadowaead_2022/service_multi_test.go @@ -22,7 +22,7 @@ func TestMultiService(t *testing.T) { var wg sync.WaitGroup - multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg}) + multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg}, nil) if err != nil { t.Fatal(err) } @@ -31,7 +31,7 @@ func TestMultiService(t *testing.T) { rand.Reader.Read(uPSK[:]) multiService.UpdateUsers([]string{"my user"}, [][]byte{uPSK[:]}) - client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}) + client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, nil) if err != nil { t.Fatal(err) } diff --git a/shadowaead_2022/service_test.go b/shadowaead_2022/service_test.go index 37b6fb4..6c2636f 100644 --- a/shadowaead_2022/service_test.go +++ b/shadowaead_2022/service_test.go @@ -21,12 +21,12 @@ func TestService(t *testing.T) { var wg sync.WaitGroup - service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg}) + service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg}, nil) if err != nil { t.Fatal(err) } - client, err := shadowaead_2022.New(method, [][]byte{psk[:]}) + client, err := shadowaead_2022.New(method, [][]byte{psk[:]}, nil) if err != nil { t.Fatal(err) } diff --git a/shadowimpl/fetcher.go b/shadowimpl/fetcher.go index a218892..3d5b43a 100644 --- a/shadowimpl/fetcher.go +++ b/shadowimpl/fetcher.go @@ -1,6 +1,8 @@ package shadowimpl import ( + "time" + "github.com/sagernet/sing-shadowsocks" "github.com/sagernet/sing-shadowsocks/shadowaead" "github.com/sagernet/sing-shadowsocks/shadowaead_2022" @@ -9,7 +11,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -func FetchMethod(method string, password string) (shadowsocks.Method, error) { +func FetchMethod(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) { if method == "none" || method == "plain" || method == "dummy" { return shadowsocks.NewNone(), nil } else if common.Contains(shadowstream.List, method) { @@ -17,7 +19,7 @@ func FetchMethod(method string, password string) (shadowsocks.Method, error) { } else if common.Contains(shadowaead.List, method) { return shadowaead.New(method, nil, password) } else if common.Contains(shadowaead_2022.List, method) { - return shadowaead_2022.NewWithPassword(method, password) + return shadowaead_2022.NewWithPassword(method, password, timeFunc) } else { return nil, E.New("shadowsocks: unsupported method ", method) }