mirror of
https://github.com/SagerNet/sing-shadowsocks.git
synced 2025-04-02 03:17:39 +03:00
Add time func support
This commit is contained in:
parent
31e4666f1e
commit
769c01d6bb
7 changed files with 58 additions and 41 deletions
|
@ -64,7 +64,7 @@ func init() {
|
|||
random.InitializeSeed()
|
||||
}
|
||||
|
||||
func NewWithPassword(method string, password string, options ...MethodOption) (shadowsocks.Method, error) {
|
||||
func NewWithPassword(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) {
|
||||
var pskList [][]byte
|
||||
if password == "" {
|
||||
return nil, ErrMissingPSK
|
||||
|
@ -78,12 +78,13 @@ func NewWithPassword(method string, password string, options ...MethodOption) (s
|
|||
}
|
||||
pskList[i] = kb
|
||||
}
|
||||
return New(method, pskList, options...)
|
||||
return New(method, pskList, timeFunc)
|
||||
}
|
||||
|
||||
func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.Method, error) {
|
||||
func New(method string, pskList [][]byte, timeFunc func() time.Time) (shadowsocks.Method, error) {
|
||||
m := &Method{
|
||||
name: method,
|
||||
name: method,
|
||||
timeFunc: timeFunc,
|
||||
}
|
||||
|
||||
switch method {
|
||||
|
@ -146,9 +147,6 @@ func New(method string, pskList [][]byte, options ...MethodOption) (shadowsocks.
|
|||
}
|
||||
|
||||
m.pskList = pskList
|
||||
for _, option := range options {
|
||||
option(m)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
@ -177,8 +175,10 @@ func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block ci
|
|||
}
|
||||
|
||||
type Method struct {
|
||||
name string
|
||||
keySaltLength int
|
||||
name string
|
||||
keySaltLength int
|
||||
timeFunc func() time.Time
|
||||
|
||||
constructor func(key []byte) (cipher.AEAD, error)
|
||||
blockConstructor func(key []byte) (cipher.Block, error)
|
||||
udpCipher cipher.AEAD
|
||||
|
@ -222,6 +222,14 @@ type clientConn struct {
|
|||
writer *shadowaead.Writer
|
||||
}
|
||||
|
||||
func (m *Method) time() time.Time {
|
||||
if m.timeFunc != nil {
|
||||
return m.timeFunc()
|
||||
} else {
|
||||
return time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
|
||||
pskLen := len(m.pskList)
|
||||
if pskLen < 2 {
|
||||
|
@ -280,7 +288,7 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
|
||||
fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:]))
|
||||
common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
|
||||
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.time().Unix())))
|
||||
var paddingLen int
|
||||
if len(payload) < MaxPaddingLength {
|
||||
paddingLen = mRand.Intn(MaxPaddingLength) + 1
|
||||
|
@ -366,7 +374,7 @@ func (c *clientConn) readResponse() error {
|
|||
return err
|
||||
}
|
||||
|
||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||
diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
|
||||
if diff > 30 {
|
||||
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
||||
}
|
||||
|
@ -526,7 +534,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
|
|||
}
|
||||
common.Must(
|
||||
header.WriteByte(HeaderTypeClient),
|
||||
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
|
||||
binary.Write(header, binary.BigEndian, uint64(c.time().Unix())),
|
||||
binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
|
||||
)
|
||||
|
||||
|
@ -632,7 +640,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
|||
return M.Socksaddr{}, err
|
||||
}
|
||||
|
||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||
diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
|
||||
if diff > 30 {
|
||||
return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
||||
}
|
||||
|
@ -641,15 +649,15 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
|||
c.session.window.Add(packetId)
|
||||
} else if sessionId == c.session.lastRemoteSessionId {
|
||||
c.session.lastWindow.Add(packetId)
|
||||
c.session.lastRemoteSeen = time.Now().Unix()
|
||||
c.session.lastRemoteSeen = c.time().Unix()
|
||||
} else {
|
||||
if c.session.remoteSessionId != 0 {
|
||||
if time.Now().Unix()-c.session.lastRemoteSeen < 60 {
|
||||
if c.time().Unix()-c.session.lastRemoteSeen < 60 {
|
||||
return M.Socksaddr{}, ErrTooManyServerSessions
|
||||
} else {
|
||||
c.session.lastRemoteSessionId = c.session.remoteSessionId
|
||||
c.session.lastWindow = c.session.window
|
||||
c.session.lastRemoteSeen = time.Now().Unix()
|
||||
c.session.lastRemoteSeen = c.time().Unix()
|
||||
c.session.lastRemoteCipher = c.session.remoteCipher
|
||||
c.session.window = SlidingWindow{}
|
||||
}
|
||||
|
@ -758,7 +766,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
}
|
||||
common.Must(
|
||||
buffer.WriteByte(HeaderTypeClient),
|
||||
binary.Write(buffer, binary.BigEndian, uint64(time.Now().Unix())),
|
||||
binary.Write(buffer, binary.BigEndian, uint64(c.time().Unix())),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length
|
||||
)
|
||||
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
package shadowaead_2022
|
||||
|
||||
type MethodOption func(*Method)
|
|
@ -42,6 +42,7 @@ type Service struct {
|
|||
name string
|
||||
keySaltLength int
|
||||
handler shadowsocks.Handler
|
||||
timeFunc func() time.Time
|
||||
|
||||
constructor func(key []byte) (cipher.AEAD, error)
|
||||
blockConstructor func(key []byte) (cipher.Block, error)
|
||||
|
@ -54,7 +55,7 @@ type Service struct {
|
|||
udpSessions *cache.LruCache[uint64, *serverUDPSession]
|
||||
}
|
||||
|
||||
func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
|
||||
if password == "" {
|
||||
return nil, ErrMissingPSK
|
||||
}
|
||||
|
@ -62,13 +63,14 @@ func NewServiceWithPassword(method string, password string, udpTimeout int64, ha
|
|||
if err != nil {
|
||||
return nil, E.Cause(err, "decode psk")
|
||||
}
|
||||
return NewService(method, psk, udpTimeout, handler)
|
||||
return NewService(method, psk, udpTimeout, handler, timeFunc)
|
||||
}
|
||||
|
||||
func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
|
||||
func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
|
||||
s := &Service{
|
||||
name: method,
|
||||
handler: handler,
|
||||
name: method,
|
||||
handler: handler,
|
||||
timeFunc: timeFunc,
|
||||
|
||||
replayFilter: replay.NewSimple(60 * time.Second),
|
||||
udpNat: udpnat.New[uint64](udpTimeout, handler),
|
||||
|
@ -135,6 +137,14 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *Service) time() time.Time {
|
||||
if s.timeFunc != nil {
|
||||
return s.timeFunc()
|
||||
} else {
|
||||
return time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength)
|
||||
|
||||
|
@ -183,7 +193,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
return err
|
||||
}
|
||||
|
||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
|
||||
if diff > 30 {
|
||||
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
||||
}
|
||||
|
@ -280,7 +290,7 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
|||
_headerFixedChunk := buf.StackNewSize(1 + 8 + c.keySaltLength + 2)
|
||||
headerFixedChunk := common.Dup(_headerFixedChunk)
|
||||
common.Must(headerFixedChunk.WriteByte(headerType))
|
||||
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix())))
|
||||
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix())))
|
||||
common.Must1(headerFixedChunk.Write(c.requestSalt))
|
||||
common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen)))
|
||||
|
||||
|
@ -469,7 +479,7 @@ process:
|
|||
if err != nil {
|
||||
goto returnErr
|
||||
}
|
||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
|
||||
if diff > 30 {
|
||||
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
||||
goto returnErr
|
||||
|
@ -539,7 +549,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
|
|||
binary.Write(header, binary.BigEndian, w.session.sessionId),
|
||||
binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
|
||||
header.WriteByte(HeaderTypeServer),
|
||||
binary.Write(header, binary.BigEndian, uint64(time.Now().Unix())),
|
||||
binary.Write(header, binary.BigEndian, uint64(w.time().Unix())),
|
||||
binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
|
||||
binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
|
||||
)
|
||||
|
|
|
@ -33,7 +33,7 @@ type MultiService[U comparable] struct {
|
|||
uCipher map[U]cipher.Block
|
||||
}
|
||||
|
||||
func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
|
||||
func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
|
||||
if password == "" {
|
||||
return nil, ErrMissingPSK
|
||||
}
|
||||
|
@ -41,10 +41,10 @@ func NewMultiServiceWithPassword[U comparable](method string, password string, u
|
|||
if err != nil {
|
||||
return nil, E.Cause(err, "decode psk")
|
||||
}
|
||||
return NewMultiService[U](method, iPSK, udpTimeout, handler)
|
||||
return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc)
|
||||
}
|
||||
|
||||
func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
|
||||
func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
|
||||
switch method {
|
||||
case "2022-blake3-aes-128-gcm":
|
||||
case "2022-blake3-aes-256-gcm":
|
||||
|
@ -52,7 +52,7 @@ func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64,
|
|||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
ss, err := NewService(method, iPSK, udpTimeout, handler)
|
||||
ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -192,7 +192,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
|
|||
if err != nil {
|
||||
return E.Cause(err, "read timestamp")
|
||||
}
|
||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
|
||||
if diff > 30 {
|
||||
return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
||||
}
|
||||
|
@ -342,7 +342,7 @@ process:
|
|||
if err != nil {
|
||||
goto returnErr
|
||||
}
|
||||
diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch))))
|
||||
diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
|
||||
if diff > 30 {
|
||||
err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
|
||||
goto returnErr
|
||||
|
|
|
@ -22,7 +22,7 @@ func TestMultiService(t *testing.T) {
|
|||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg})
|
||||
multiService, err := shadowaead_2022.NewMultiService[string](method, iPSK[:], 500, &multiHandler{t, &wg}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func TestMultiService(t *testing.T) {
|
|||
rand.Reader.Read(uPSK[:])
|
||||
multiService.UpdateUsers([]string{"my user"}, [][]byte{uPSK[:]})
|
||||
|
||||
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]})
|
||||
client, err := shadowaead_2022.New(method, [][]byte{iPSK[:], uPSK[:]}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -21,12 +21,12 @@ func TestService(t *testing.T) {
|
|||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg})
|
||||
service, err := shadowaead_2022.NewService(method, psk[:], 500, &multiHandler{t, &wg}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client, err := shadowaead_2022.New(method, [][]byte{psk[:]})
|
||||
client, err := shadowaead_2022.New(method, [][]byte{psk[:]}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package shadowimpl
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-shadowsocks"
|
||||
"github.com/sagernet/sing-shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing-shadowsocks/shadowaead_2022"
|
||||
|
@ -9,7 +11,7 @@ import (
|
|||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func FetchMethod(method string, password string) (shadowsocks.Method, error) {
|
||||
func FetchMethod(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) {
|
||||
if method == "none" || method == "plain" || method == "dummy" {
|
||||
return shadowsocks.NewNone(), nil
|
||||
} else if common.Contains(shadowstream.List, method) {
|
||||
|
@ -17,7 +19,7 @@ func FetchMethod(method string, password string) (shadowsocks.Method, error) {
|
|||
} else if common.Contains(shadowaead.List, method) {
|
||||
return shadowaead.New(method, nil, password)
|
||||
} else if common.Contains(shadowaead_2022.List, method) {
|
||||
return shadowaead_2022.NewWithPassword(method, password)
|
||||
return shadowaead_2022.NewWithPassword(method, password, timeFunc)
|
||||
} else {
|
||||
return nil, E.New("shadowsocks: unsupported method ", method)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue