sing/protocol/socks/protocol.go
2022-05-02 11:18:46 +08:00

365 lines
8.5 KiB
Go

package socks
import (
"bytes"
"io"
"net"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/rw"
)
// AuthRequest is the SOCKS5 method-selection message sent by the client:
//
//	+----+----------+----------+
//	|VER | NMETHODS | METHODS  |
//	+----+----------+----------+
//	| 1  |    1     | 1 to 255 |
//	+----+----------+----------+
//
// NMETHODS is not stored; it is derived from len(Methods) when writing.
// Keep the field order stable: positional composite literals of this type
// may exist.
type AuthRequest struct {
	Version byte   // VER
	Methods []byte // METHODS, one byte per offered authentication method
}
// WriteAuthRequest writes request in the method-selection wire format:
// VER, NMETHODS (= len(request.Methods)) and then the raw METHODS bytes.
//
// NMETHODS is a single byte, so at most 255 methods can be encoded. A longer
// Methods slice is rejected with an error before anything is written. It
// previously had its count silently truncated, leaving the peer to misparse
// the stream.
func WriteAuthRequest(writer io.Writer, request *AuthRequest) error {
	if len(request.Methods) > 255 {
		return E.New("too many socks auth methods")
	}
	if err := rw.WriteByte(writer, request.Version); err != nil {
		return err
	}
	if err := rw.WriteByte(writer, byte(len(request.Methods))); err != nil {
		return err
	}
	return rw.WriteBytes(writer, request.Methods)
}
// ReadAuthRequest reads a complete method-selection message: the VER byte
// (which must be Version5, otherwise an *UnsupportedVersionException is
// returned), the NMETHODS count, and that many METHODS bytes.
// Failures while reading the count or the methods are wrapped with context.
func ReadAuthRequest(reader io.Reader) (*AuthRequest, error) {
	version, err := rw.ReadByte(reader)
	switch {
	case err != nil:
		return nil, err
	case version != Version5:
		return nil, &UnsupportedVersionException{version}
	}
	count, err := rw.ReadByte(reader)
	if err != nil {
		return nil, E.Cause(err, "read socks auth methods length")
	}
	methods, err := rw.ReadBytes(reader, int(count))
	if err != nil {
		return nil, E.CauseF(err, "read socks auth methods, length ", count)
	}
	return &AuthRequest{Version: version, Methods: methods}, nil
}
// ReadAuthRequest0 reads the remainder of a method-selection message without
// consuming a VER byte: it reads NMETHODS and then that many METHODS bytes.
// The returned request always reports Version5. It is presumably used after
// the caller has already read the version byte itself (confirm at call sites).
func ReadAuthRequest0(reader io.Reader) (*AuthRequest, error) {
	count, err := rw.ReadByte(reader)
	if err != nil {
		return nil, E.Cause(err, "read socks auth methods length")
	}
	methods, err := rw.ReadBytes(reader, int(count))
	if err != nil {
		return nil, E.CauseF(err, "read socks auth methods, length ", count)
	}
	return &AuthRequest{Version: Version5, Methods: methods}, nil
}
// AuthResponse is the server's reply to an AuthRequest, carrying the
// authentication method it selected:
//
//	+----+--------+
//	|VER | METHOD |
//	+----+--------+
//	| 1  |   1    |
//	+----+--------+
type AuthResponse struct {
	Version, Method byte
}
// WriteAuthResponse writes response as two bytes: VER followed by METHOD.
// It stops at, and returns, the first write error.
func WriteAuthResponse(writer io.Writer, response *AuthResponse) error {
	if err := rw.WriteByte(writer, response.Version); err != nil {
		return err
	}
	return rw.WriteByte(writer, response.Method)
}
// ReadAuthResponse reads a VER byte, which must equal Version5 (otherwise an
// *UnsupportedVersionException is returned), and then the selected METHOD byte.
func ReadAuthResponse(reader io.Reader) (*AuthResponse, error) {
	var (
		response AuthResponse
		err      error
	)
	if response.Version, err = rw.ReadByte(reader); err != nil {
		return nil, err
	}
	if response.Version != Version5 {
		return nil, &UnsupportedVersionException{response.Version}
	}
	if response.Method, err = rw.ReadByte(reader); err != nil {
		return nil, err
	}
	return &response, nil
}
// UsernamePasswordAuthRequest holds the credentials exchanged during the
// username/password sub-negotiation:
//
//	+----+------+----------+------+----------+
//	|VER | ULEN |  UNAME   | PLEN |  PASSWD  |
//	+----+------+----------+------+----------+
//	| 1  |  1   | 1 to 255 |  1   | 1 to 255 |
//	+----+------+----------+------+----------+
//
// VER, ULEN and PLEN are not stored; they are produced and consumed by the
// read/write helpers.
type UsernamePasswordAuthRequest struct {
	Username, Password string
}
// WriteUsernamePasswordAuthRequest writes UsernamePasswordVersion1 and then
// the username and the password via M.WriteString. M.WriteString presumably
// emits a length byte followed by the string bytes, and the "username" and
// "password" labels are presumably used in its error messages (confirm in the
// metadata package). It stops at, and returns, the first error.
func WriteUsernamePasswordAuthRequest(writer io.Writer, request *UsernamePasswordAuthRequest) error {
	if err := rw.WriteByte(writer, UsernamePasswordVersion1); err != nil {
		return err
	}
	if err := M.WriteString(writer, "username", request.Username); err != nil {
		return err
	}
	return M.WriteString(writer, "password", request.Password)
}
// ReadUsernamePasswordAuthRequest reads a VER byte, which must be
// UsernamePasswordVersion1 (otherwise an *UnsupportedVersionException is
// returned), and then the username and the password via M.ReadString.
func ReadUsernamePasswordAuthRequest(reader io.Reader) (*UsernamePasswordAuthRequest, error) {
	version, err := rw.ReadByte(reader)
	switch {
	case err != nil:
		return nil, err
	case version != UsernamePasswordVersion1:
		return nil, &UnsupportedVersionException{version}
	}
	var request UsernamePasswordAuthRequest
	if request.Username, err = M.ReadString(reader); err != nil {
		return nil, err
	}
	if request.Password, err = M.ReadString(reader); err != nil {
		return nil, err
	}
	return &request, nil
}
// UsernamePasswordAuthResponse is the server's verdict on a
// UsernamePasswordAuthRequest:
//
//	+----+--------+
//	|VER | STATUS |
//	+----+--------+
//	| 1  |   1    |
//	+----+--------+
//
// VER is not stored: it is always written as UsernamePasswordVersion1 and
// checked against that value when reading.
type UsernamePasswordAuthResponse struct {
	Status byte // STATUS
}
// WriteUsernamePasswordAuthResponse writes the fixed UsernamePasswordVersion1
// byte followed by response.Status. It stops at, and returns, the first error.
func WriteUsernamePasswordAuthResponse(writer io.Writer, response *UsernamePasswordAuthResponse) error {
	if err := rw.WriteByte(writer, UsernamePasswordVersion1); err != nil {
		return err
	}
	return rw.WriteByte(writer, response.Status)
}
// ReadUsernamePasswordAuthResponse reads a VER byte, which must be
// UsernamePasswordVersion1 (otherwise an *UnsupportedVersionException is
// returned), and then the STATUS byte.
//
// Any status other than UsernamePasswordStatusSuccess is normalized to
// UsernamePasswordStatusFailure, so callers only need to compare against
// those two values.
//
// A read error on STATUS is now returned. Previously it was ignored, so a
// truncated or closed stream surfaced as an authentication failure with a
// nil error.
func ReadUsernamePasswordAuthResponse(reader io.Reader) (*UsernamePasswordAuthResponse, error) {
	version, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	if version != UsernamePasswordVersion1 {
		return nil, &UnsupportedVersionException{version}
	}
	status, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	if status != UsernamePasswordStatusSuccess {
		status = UsernamePasswordStatusFailure
	}
	return &UsernamePasswordAuthResponse{Status: status}, nil
}
// Request is the client's SOCKS request message:
//
//	+----+-----+-------+------+----------+----------+
//	|VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
//	+----+-----+-------+------+----------+----------+
//	| 1  |  1  | X'00' |  1   | Variable |    2     |
//	+----+-----+-------+------+----------+----------+
//
// RSV is implicit. Destination carries DST.ADDR and DST.PORT; ATYP is
// presumably derived from it by AddressSerializer.
type Request struct {
	Version, Command byte
	Destination      *M.AddrPort
}
// WriteRequest writes VER and CMD, then the RSV byte via rw.WriteZero
// (presumably a single 0x00), then request.Destination through
// AddressSerializer. It stops at, and returns, the first error.
func WriteRequest(writer io.Writer, request *Request) error {
	for _, b := range [...]byte{request.Version, request.Command} {
		if err := rw.WriteByte(writer, b); err != nil {
			return err
		}
	}
	if err := rw.WriteZero(writer); err != nil {
		return err
	}
	return AddressSerializer.WriteAddrPort(writer, request.Destination)
}
// ReadRequest reads a request header with the SOCKS5 layout. VER may be
// Version4 or Version5, and CMD must be CommandConnect or
// CommandUDPAssociate. Anything else yields an *UnsupportedVersionException
// or an *UnsupportedCommandException. The RSV byte is skipped and the
// destination is decoded by AddressSerializer.
//
// NOTE(review): a SOCKS4 request has a different layout (VN CD DSTPORT DSTIP
// USERID). Confirm that callers never route SOCKS4 traffic through this
// function.
func ReadRequest(reader io.Reader) (*Request, error) {
	version, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	switch version {
	case Version4, Version5:
	default:
		return nil, &UnsupportedVersionException{version}
	}
	command, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	switch command {
	case CommandConnect, CommandUDPAssociate:
	default:
		return nil, &UnsupportedCommandException{command}
	}
	if err = rw.Skip(reader); err != nil {
		return nil, err
	}
	destination, err := AddressSerializer.ReadAddrPort(reader)
	if err != nil {
		return nil, err
	}
	return &Request{Version: version, Command: command, Destination: destination}, nil
}
// Response is the server's reply to a Request:
//
//	+----+-----+-------+------+----------+----------+
//	|VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
//	+----+-----+-------+------+----------+----------+
//	| 1  |  1  | X'00' |  1   | Variable |    2     |
//	+----+-----+-------+------+----------+----------+
//
// RSV is implicit. Bind carries BND.ADDR and BND.PORT; ATYP is presumably
// derived from it by AddressSerializer. A nil Bind is written as 0.0.0.0:0.
type Response struct {
	Version   byte
	ReplyCode ReplyCode   // REP
	Bind      *M.AddrPort // may be nil when writing
}
// WriteResponse writes VER and REP, then the RSV byte via rw.WriteZero, then
// the bind address through AddressSerializer. When response.Bind is nil, the
// IPv4 zero address with port 0 is written instead. It stops at, and
// returns, the first error.
func WriteResponse(writer io.Writer, response *Response) error {
	for _, b := range [...]byte{response.Version, byte(response.ReplyCode)} {
		if err := rw.WriteByte(writer, b); err != nil {
			return err
		}
	}
	if err := rw.WriteZero(writer); err != nil {
		return err
	}
	if response.Bind != nil {
		return AddressSerializer.WriteAddrPort(writer, response.Bind)
	}
	return AddressSerializer.WriteAddrPort(writer, M.AddrPortFrom(M.AddrFromIP(net.IPv4zero), 0))
}
// ReadResponse reads a reply header with the SOCKS5 layout. VER must be
// Version4 or Version5 (otherwise an *UnsupportedVersionException is
// returned). REP is not validated. The RSV byte is skipped and the bind
// address is decoded by AddressSerializer.
func ReadResponse(reader io.Reader) (*Response, error) {
	version, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	if version != Version4 && version != Version5 {
		return nil, &UnsupportedVersionException{version}
	}
	code, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	if err = rw.Skip(reader); err != nil {
		return nil, err
	}
	bind, err := AddressSerializer.ReadAddrPort(reader)
	if err != nil {
		return nil, err
	}
	return &Response{Version: version, ReplyCode: ReplyCode(code), Bind: bind}, nil
}
// AssociatePacket is a UDP datagram exchanged over a SOCKS5 UDP ASSOCIATE
// relay, with its header decoded:
//
//	+----+------+------+----------+----------+----------+
//	|RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
//	+----+------+------+----------+----------+----------+
//	| 2  |  1   |  1   | Variable |    2     | Variable |
//	+----+------+------+----------+----------+----------+
//
// The two RSV bytes are not stored: they are written as zero when encoding
// and skipped when decoding.
type AssociatePacket struct {
	Fragment    byte        // FRAG
	Destination *M.AddrPort // DST.ADDR / DST.PORT (ATYP via AddressSerializer)
	Data        []byte      // DATA, the payload
}
// DecodeAssociatePacket parses the UDP request header at the front of buffer
// and returns the decoded packet.
//
// On success, buffer is advanced past the RSV/FRAG/ATYP/DST.ADDR/DST.PORT
// header, so buffer.Bytes(), and therefore the returned packet's Data, is the
// payload only. FRAG is returned as-is; dropping or reassembling fragmented
// datagrams is left to the caller.
//
// Previously the buffer was advanced by reader.Len(), which is the number of
// UNREAD bytes (the payload length) rather than the header length, so Data
// was wrong whenever the header and payload sizes differed.
//
// NOTE(review): assumes buf.Buffer.Advance(n) discards n bytes from the front
// of the readable region. Confirm against the buf package.
func DecodeAssociatePacket(buffer *buf.Buffer) (*AssociatePacket, error) {
	if buffer.Len() < 5 {
		return nil, E.New("insufficient length")
	}
	fragment := buffer.Byte(2)
	reader := bytes.NewReader(buffer.Bytes())
	// Skip RSV (2 bytes) and FRAG (1 byte); FRAG was already read above.
	err := common.Error(reader.Seek(3, io.SeekStart))
	if err != nil {
		return nil, err
	}
	addrPort, err := AddressSerializer.ReadAddrPort(reader)
	if err != nil {
		return nil, err
	}
	// Size() is the full input length and Len() is what remains unread, so
	// their difference is exactly the number of header bytes consumed.
	headerLen := int(reader.Size()) - reader.Len()
	buffer.Advance(headerLen)
	return &AssociatePacket{
		Fragment:    fragment,
		Destination: addrPort,
		Data:        buffer.Bytes(),
	}, nil
}
// EncodeAssociatePacket appends packet to buffer in UDP request format: two
// zero RSV bytes, FRAG, the destination through AddressSerializer, and then
// the raw payload. It stops at, and returns, the first error.
func EncodeAssociatePacket(packet *AssociatePacket, buffer *buf.Buffer) error {
	if err := rw.WriteZeroN(buffer, 2); err != nil {
		return err
	}
	if err := rw.WriteByte(buffer, packet.Fragment); err != nil {
		return err
	}
	if err := AddressSerializer.WriteAddrPort(buffer, packet.Destination); err != nil {
		return err
	}
	_, err := buffer.Write(packet.Data)
	return err
}