diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index 999cef9..19b418c 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -444,7 +444,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad if pskLen > 1 { panic("unsupported chacha extended header") } - dataIndex = buffer.Len() + dataIndex = PacketNonceSize } else { dataIndex = aes.BlockSize } @@ -649,7 +649,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if pskLen > 1 { panic("unsupported chacha extended header") } - dataIndex = buffer.Len() + dataIndex = PacketNonceSize } else { dataIndex = aes.BlockSize } diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index dd2d282..6fdf2c2 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -8,6 +8,7 @@ import ( "io" "math" "net" + "os" "runtime" "sync" "sync/atomic" @@ -69,16 +70,20 @@ func NewService(method string, psk []byte, password string, secureRNG io.Reader, case "2022-blake3-chacha20-poly1305": s.keySaltLength = 32 s.constructor = newChacha20Poly1305 + default: + return nil, os.ErrInvalid } - if len(psk) < s.keySaltLength { - return nil, shadowsocks.ErrBadKey - } else if len(psk) > s.keySaltLength { - psk = Key(psk, s.keySaltLength) - } else if password == "" { - return nil, ErrMissingPasswordPSK - } else { - psk = Key([]byte(password), s.keySaltLength) + if len(psk) != s.keySaltLength { + if len(psk) < s.keySaltLength { + return nil, shadowsocks.ErrBadKey + } else if len(psk) > s.keySaltLength { + psk = Key(psk, s.keySaltLength) + } else if password == "" { + return nil, ErrMissingPasswordPSK + } else { + psk = Key([]byte(password), s.keySaltLength) + } } switch method { @@ -395,7 +400,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks var dataIndex int if w.udpCipher != nil { common.Must1(header.ReadFullFrom(w.secureRNG, PacketNonceSize)) - dataIndex = buffer.Len() + dataIndex = PacketNonceSize } else { dataIndex = aes.BlockSize }