Add time func support

This commit is contained in:
世界 2023-02-21 14:19:54 +08:00
parent 31e4666f1e
commit 769c01d6bb
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 58 additions and 41 deletions

View file

@ -64,7 +64,7 @@ func init() {
random.InitializeSeed() random.InitializeSeed()
} }
func NewWithPassword(method string, password string, options ...MethodOption) (shadowsocks.Method, error) { func NewWithPassword(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) {
var pskList [][]byte var pskList [][]byte
if password == "" { if password == "" {
return nil, ErrMissingPSK return nil, ErrMissingPSK
@ -78,12 +78,13 @@ func NewWithPassword(method string, password string, options ...MethodOption) (s
} }
pskList[i] = kb pskList[i] = kb
} }
return New(method, pskList, options...) return New(method, pskList, timeFunc)
} }
func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.Method, error) { func New(method string, pskList [][]byte, timeFunc func() time.Time) (shadowsocks.Method, error) {
m := &Method{ m := &Method{
name: method, name: method,
timeFunc: timeFunc,
} }
switch method { switch method {
@ -146,9 +147,6 @@ func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.
} }
m.pskList = pskList m.pskList = pskList
for _, option := range options {
option(m)
}
return m, nil return m, nil
} }
@ -177,8 +175,10 @@ func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block ci
} }
type Method struct { type Method struct {
name string name string
keySaltLength int keySaltLength int
timeFunc func() time.Time
constructor func(key []byte) (cipher.AEAD, error) constructor func(key []byte) (cipher.AEAD, error)
blockConstructor func(key []byte) (cipher.Block, error) blockConstructor func(key []byte) (cipher.Block, error)
udpCipher cipher.AEAD udpCipher cipher.AEAD
@ -222,6 +222,14 @@ type clientConn struct {
writer *shadowaead.Writer writer *shadowaead.Writer
} }
func (m *Method) time() time.Time {
if m.timeFunc != nil {
return m.timeFunc()
} else {
return time.Now()
}
}
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error { func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
pskLen := len(m.pskList) pskLen := len(m.pskList)
if pskLen < 2 { if pskLen < 2 {
@ -280,7 +288,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:])) fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:]))
common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix()))) common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.time().Unix())))
var paddingLen int var paddingLen int
if len(payload) < MaxPaddingLength { if len(payload) < MaxPaddingLength {
paddingLen = mRand.Intn(MaxPaddingLength) + 1 paddingLen = mRand.Intn(MaxPaddingLength) + 1
@ -366,7 +374,7 @@ func (c *clientConn) readResponse() error {
return err return err
} }
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
if diff > 30 { if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
} }
@ -526,7 +534,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
} }
common.Must( common.Must(
header.WriteByte(HeaderTypeClient), header.WriteByte(HeaderTypeClient),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), binary.Write(header, binary.BigEndian, uint64(c.time().Unix())),
binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
) )
@ -632,7 +640,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return M.Socksaddr{}, err return M.Socksaddr{}, err
} }
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
if diff > 30 { if diff > 30 {
return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
} }
@ -641,15 +649,15 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
c.session.window.Add(packetId) c.session.window.Add(packetId)
} else if sessionId == c.session.lastRemoteSessionId { } else if sessionId == c.session.lastRemoteSessionId {
c.session.lastWindow.Add(packetId) c.session.lastWindow.Add(packetId)
c.session.lastRemoteSeen = time.Now().Unix() c.session.lastRemoteSeen = c.time().Unix()
} else { } else {
if c.session.remoteSessionId != 0 { if c.session.remoteSessionId != 0 {
if time.Now().Unix()-c.session.lastRemoteSeen < 60 { if c.time().Unix()-c.session.lastRemoteSeen < 60 {
return M.Socksaddr{}, ErrTooManyServerSessions return M.Socksaddr{}, ErrTooManyServerSessions
} else { } else {
c.session.lastRemoteSessionId = c.session.remoteSessionId c.session.lastRemoteSessionId = c.session.remoteSessionId
c.session.lastWindow = c.session.window c.session.lastWindow = c.session.window
c.session.lastRemoteSeen = time.Now().Unix() c.session.lastRemoteSeen = c.time().Unix()
c.session.lastRemoteCipher = c.session.remoteCipher c.session.lastRemoteCipher = c.session.remoteCipher
c.session.window = SlidingWindow{} c.session.window = SlidingWindow{}
} }
@ -758,7 +766,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
} }
common.Must( common.Must(
buffer.WriteByte(HeaderTypeClient), buffer.WriteByte(HeaderTypeClient),
binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())), binary.Write(buffer, binary.BigEndian, uint64(c.time().Unix())),
binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length
) )

View file

@ -1,3 +0,0 @@
package shadowaead_2022
type MethodOption func(*Method)

View file

@ -42,6 +42,7 @@ type Service struct {
name string name string
keySaltLength int keySaltLength int
handler shadowsocks.Handler handler shadowsocks.Handler
timeFunc func() time.Time
constructor func(key []byte) (cipher.AEAD, error) constructor func(key []byte) (cipher.AEAD, error)
blockConstructor func(key []byte) (cipher.Block, error) blockConstructor func(key []byte) (cipher.Block, error)
@ -54,7 +55,7 @@ type Service struct {
udpSessions *cache.LruCache[uint64, *serverUDPSession] udpSessions *cache.LruCache[uint64, *serverUDPSession]
} }
func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
if password == "" { if password == "" {
return nil, ErrMissingPSK return nil, ErrMissingPSK
} }
@ -62,13 +63,14 @@ func NewServiceWithPassword(method string, password string, udpTimeout int64, ha
if err != nil { if err != nil {
return nil, E.Cause(err, "decode psk") return nil, E.Cause(err, "decode psk")
} }
return NewService(method, psk, udpTimeout, handler) return NewService(method, psk, udpTimeout, handler, timeFunc)
} }
func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
s := &Service{ s := &Service{
name: method, name: method,
handler: handler, handler: handler,
timeFunc: timeFunc,
replayFilter: replay.NewSimple(60 * time.Second), replayFilter: replay.NewSimple(60 * time.Second),
udpNat: udpnat.New[uint64](udpTimeout, handler), udpNat: udpnat.New[uint64](udpTimeout, handler),
@ -135,6 +137,14 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
return err return err
} }
func (s *Service) time() time.Time {
if s.timeFunc != nil {
return s.timeFunc()
} else {
return time.Now()
}
}
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength) header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength)
@ -183,7 +193,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return err return err
} }
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
if diff > 30 { if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
} }
@ -280,7 +290,7 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
_headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2) _headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2)
headerFixedChunk := common.Dup(_headerFixedChunk) headerFixedChunk := common.Dup(_headerFixedChunk)
common.Must(headerFixedChunk.WriteByte(headerType)) common.Must(headerFixedChunk.WriteByte(headerType))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix()))) common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix())))
common.Must1(headerFixedChunk.Write(c.requestSalt)) common.Must1(headerFixedChunk.Write(c.requestSalt))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen))) common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen)))
@ -469,7 +479,7 @@ process:
if err != nil { if err != nil {
goto returnErr goto returnErr
} }
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
if diff > 30 { if diff > 30 {
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
goto returnErr goto returnErr
@ -539,7 +549,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
binary.Write(header, binary.BigEndian, w.session.sessionId), binary.Write(header, binary.BigEndian, w.session.sessionId),
binary.Write(header, binary.BigEndian, w.session.nextPacketId()), binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
header.WriteByte(HeaderTypeServer), header.WriteByte(HeaderTypeServer),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())), binary.Write(header, binary.BigEndian, uint64(w.time().Unix())),
binary.Write(header, binary.BigEndian, w.session.remoteSessionId), binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
) )

View file

@ -33,7 +33,7 @@ type MultiService[U comparable] struct {
uCipher map[U]cipher.Block uCipher map[U]cipher.Block
} }
func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
if password == "" { if password == "" {
return nil, ErrMissingPSK return nil, ErrMissingPSK
} }
@ -41,10 +41,10 @@ func NewMultiServiceWithPassword[U comparable](method string, password string, u
if err != nil { if err != nil {
return nil, E.Cause(err, "decode psk") return nil, E.Cause(err, "decode psk")
} }
return NewMultiService[U](method, iPSK, udpTimeout, handler) return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc)
} }
func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) { func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
switch method { switch method {
case "2022-blake3-aes-128-gcm": case "2022-blake3-aes-128-gcm":
case "2022-blake3-aes-256-gcm": case "2022-blake3-aes-256-gcm":
@ -52,7 +52,7 @@ func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64,
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
ss, err := NewService(method, iPSK, udpTimeout, handler) ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -192,7 +192,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
if err != nil { if err != nil {
return E.Cause(err, "read timestamp") return E.Cause(err, "read timestamp")
} }
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
if diff > 30 { if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
} }
@ -342,7 +342,7 @@ process:
if err != nil { if err != nil {
goto returnErr goto returnErr
} }
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
if diff > 30 { if diff > 30 {
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
goto returnErr goto returnErr

View file

@ -22,7 +22,7 @@ func TestMultiService(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg}) multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg}, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -31,7 +31,7 @@ func TestMultiService(t *testing.T) {
rand.Reader.Read(uPSK[:]) rand.Reader.Read(uPSK[:])
multiService.UpdateUsers([]string{"my user"}, [][]byte{uPSK[:]}) multiService.UpdateUsers([]string{"my user"}, [][]byte{uPSK[:]})
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}) client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -21,12 +21,12 @@ func TestService(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg}) service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg}, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
client, err := shadowaead_2022.New(method, [][]byte{psk[:]}) client, err := shadowaead_2022.New(method, [][]byte{psk[:]}, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,6 +1,8 @@
package shadowimpl package shadowimpl
import ( import (
"time"
"github.com/sagernet/sing-shadowsocks" "github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead" "github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing-shadowsocks/shadowaead_2022" "github.com/sagernet/sing-shadowsocks/shadowaead_2022"
@ -9,7 +11,7 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
func FetchMethod(method string, password string) (shadowsocks.Method, error) { func FetchMethod(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) {
if method == "none" || method == "plain" || method == "dummy" { if method == "none" || method == "plain" || method == "dummy" {
return shadowsocks.NewNone(), nil return shadowsocks.NewNone(), nil
} else if common.Contains(shadowstream.List, method) { } else if common.Contains(shadowstream.List, method) {
@ -17,7 +19,7 @@ func FetchMethod(method string, password string) (shadowsocks.Method, error) {
} else if common.Contains(shadowaead.List, method) { } else if common.Contains(shadowaead.List, method) {
return shadowaead.New(method, nil, password) return shadowaead.New(method, nil, password)
} else if common.Contains(shadowaead_2022.List, method) { } else if common.Contains(shadowaead_2022.List, method) {
return shadowaead_2022.NewWithPassword(method, password) return shadowaead_2022.NewWithPassword(method, password, timeFunc)
} else { } else {
return nil, E.New("shadowsocks: unsupported method ", method) return nil, E.New("shadowsocks: unsupported method ", method)
} }