Add time func support

This commit is contained in:
世界 2023-02-21 14:19:54 +08:00
parent 31e4666f1e
commit 769c01d6bb
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 58 additions and 41 deletions

View file

@ -64,7 +64,7 @@ func init() {
random.InitializeSeed()
}
func NewWithPassword(method string, password string, options ...MethodOption) (shadowsocks.Method, error) {
func NewWithPassword(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) {
var pskList [][]byte
if password == "" {
return nil, ErrMissingPSK
@ -78,12 +78,13 @@ func NewWithPassword(method string, password string, options ...MethodOption) (s
}
pskList[i] = kb
}
return New(method, pskList, options...)
return New(method, pskList, timeFunc)
}
func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.Method, error) {
func New(method string, pskList [][]byte, timeFunc func() time.Time) (shadowsocks.Method, error) {
m := &Method{
name: method,
name: method,
timeFunc: timeFunc,
}
switch method {
@ -146,9 +147,6 @@ func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.
}
m.pskList = pskList
for _, option := range options {
option(m)
}
return m, nil
}
@ -177,8 +175,10 @@ func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block ci
}
type Method struct {
name string
keySaltLength int
name string
keySaltLength int
timeFunc func() time.Time
constructor func(key []byte) (cipher.AEAD, error)
blockConstructor func(key []byte) (cipher.Block, error)
udpCipher cipher.AEAD
@ -222,6 +222,14 @@ type clientConn struct {
writer *shadowaead.Writer
}
func (m *Method) time() time.Time {
if m.timeFunc != nil {
return m.timeFunc()
} else {
return time.Now()
}
}
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
pskLen := len(m.pskList)
if pskLen < 2 {
@ -280,7 +288,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:]))
common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix())))
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.time().Unix())))
var paddingLen int
if len(payload) < MaxPaddingLength {
paddingLen = mRand.Intn(MaxPaddingLength) + 1
@ -366,7 +374,7 @@ func (c *clientConn) readResponse() error {
return err
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
}
@ -526,7 +534,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
}
common.Must(
header.WriteByte(HeaderTypeClient),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, uint64(c.time().Unix())),
binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
)
@ -632,7 +640,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return M.Socksaddr{}, err
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
if diff > 30 {
return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
}
@ -641,15 +649,15 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
c.session.window.Add(packetId)
} else if sessionId == c.session.lastRemoteSessionId {
c.session.lastWindow.Add(packetId)
c.session.lastRemoteSeen = time.Now().Unix()
c.session.lastRemoteSeen = c.time().Unix()
} else {
if c.session.remoteSessionId != 0 {
if time.Now().Unix()-c.session.lastRemoteSeen < 60 {
if c.time().Unix()-c.session.lastRemoteSeen < 60 {
return M.Socksaddr{}, ErrTooManyServerSessions
} else {
c.session.lastRemoteSessionId = c.session.remoteSessionId
c.session.lastWindow = c.session.window
c.session.lastRemoteSeen = time.Now().Unix()
c.session.lastRemoteSeen = c.time().Unix()
c.session.lastRemoteCipher = c.session.remoteCipher
c.session.window = SlidingWindow{}
}
@ -758,7 +766,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
}
common.Must(
buffer.WriteByte(HeaderTypeClient),
binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(buffer, binary.BigEndian, uint64(c.time().Unix())),
binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length
)

View file

@ -1,3 +0,0 @@
package shadowaead_2022
type MethodOption func(*Method)

View file

@ -42,6 +42,7 @@ type Service struct {
name string
keySaltLength int
handler shadowsocks.Handler
timeFunc func() time.Time
constructor func(key []byte) (cipher.AEAD, error)
blockConstructor func(key []byte) (cipher.Block, error)
@ -54,7 +55,7 @@ type Service struct {
udpSessions *cache.LruCache[uint64, *serverUDPSession]
}
func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
if password == "" {
return nil, ErrMissingPSK
}
@ -62,13 +63,14 @@ func NewServiceWithPassword(method string, password string, udpTimeout int64, ha
if err != nil {
return nil, E.Cause(err, "decode psk")
}
return NewService(method, psk, udpTimeout, handler)
return NewService(method, psk, udpTimeout, handler, timeFunc)
}
func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
s := &Service{
name: method,
handler: handler,
name: method,
handler: handler,
timeFunc: timeFunc,
replayFilter: replay.NewSimple(60 * time.Second),
udpNat: udpnat.New[uint64](udpTimeout, handler),
@ -135,6 +137,14 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
return err
}
func (s *Service) time() time.Time {
if s.timeFunc != nil {
return s.timeFunc()
} else {
return time.Now()
}
}
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength)
@ -183,7 +193,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
return err
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
}
@ -280,7 +290,7 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
_headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2)
headerFixedChunk := common.Dup(_headerFixedChunk)
common.Must(headerFixedChunk.WriteByte(headerType))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix())))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix())))
common.Must1(headerFixedChunk.Write(c.requestSalt))
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen)))
@ -469,7 +479,7 @@ process:
if err != nil {
goto returnErr
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
if diff > 30 {
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
goto returnErr
@ -539,7 +549,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
binary.Write(header, binary.BigEndian, w.session.sessionId),
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
header.WriteByte(HeaderTypeServer),
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
binary.Write(header, binary.BigEndian, uint64(w.time().Unix())),
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
)

View file

@ -33,7 +33,7 @@ type MultiService[U comparable] struct {
uCipher map[U]cipher.Block
}
func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
if password == "" {
return nil, ErrMissingPSK
}
@ -41,10 +41,10 @@ func NewMultiServiceWithPassword[U comparable](method string, password string, u
if err != nil {
return nil, E.Cause(err, "decode psk")
}
return NewMultiService[U](method, iPSK, udpTimeout, handler)
return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc)
}
func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
switch method {
case "2022-blake3-aes-128-gcm":
case "2022-blake3-aes-256-gcm":
@ -52,7 +52,7 @@ func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64,
return nil, os.ErrInvalid
}
ss, err := NewService(method, iPSK, udpTimeout, handler)
ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc)
if err != nil {
return nil, err
}
@ -192,7 +192,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
if err != nil {
return E.Cause(err, "read timestamp")
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
if diff > 30 {
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
}
@ -342,7 +342,7 @@ process:
if err != nil {
goto returnErr
}
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
if diff > 30 {
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
goto returnErr

View file

@ -22,7 +22,7 @@ func TestMultiService(t *testing.T) {
var wg sync.WaitGroup
multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg})
multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg}, nil)
if err != nil {
t.Fatal(err)
}
@ -31,7 +31,7 @@ func TestMultiService(t *testing.T) {
rand.Reader.Read(uPSK[:])
multiService.UpdateUsers([]string{"my user"}, [][]byte{uPSK[:]})
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]})
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, nil)
if err != nil {
t.Fatal(err)
}

View file

@ -21,12 +21,12 @@ func TestService(t *testing.T) {
var wg sync.WaitGroup
service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg})
service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg}, nil)
if err != nil {
t.Fatal(err)
}
client, err := shadowaead_2022.New(method, [][]byte{psk[:]})
client, err := shadowaead_2022.New(method, [][]byte{psk[:]}, nil)
if err != nil {
t.Fatal(err)
}

View file

@ -1,6 +1,8 @@
package shadowimpl
import (
"time"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowaead"
"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
@ -9,7 +11,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
)
func FetchMethod(method string, password string) (shadowsocks.Method, error) {
func FetchMethod(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) {
if method == "none" || method == "plain" || method == "dummy" {
return shadowsocks.NewNone(), nil
} else if common.Contains(shadowstream.List, method) {
@ -17,7 +19,7 @@ func FetchMethod(method string, password string) (shadowsocks.Method, error) {
} else if common.Contains(shadowaead.List, method) {
return shadowaead.New(method, nil, password)
} else if common.Contains(shadowaead_2022.List, method) {
return shadowaead_2022.NewWithPassword(method, password)
return shadowaead_2022.NewWithPassword(method, password, timeFunc)
} else {
return nil, E.New("shadowsocks: unsupported method ", method)
}