diff --git a/common/auth/auth.go b/common/auth/auth.go index b1be60d..81597b6 100644 --- a/common/auth/auth.go +++ b/common/auth/auth.go @@ -1,6 +1,23 @@ package auth -import "github.com/sagernet/sing/common" +import ( + "crypto/md5" + "encoding/hex" + "fmt" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/param" +) + +const Realm = "sing-box" + +type Challenge struct { + Username string + Nonce string + CNonce string + Nc string + Response string +} type User struct { Username string @@ -28,3 +45,55 @@ func (au *Authenticator) Verify(username string, password string) bool { passwordList, ok := au.userMap[username] return ok && common.Contains(passwordList, password) } + +func (au *Authenticator) VerifyDigest(method string, uri string, s string) (string, bool) { + c, err := ParseChallenge(s) + if err != nil { + return "", false + } + if c.Username == "" || c.Nonce == "" || c.Nc == "" || c.CNonce == "" || c.Response == "" { + return "", false + } + passwordList, ok := au.userMap[c.Username] + if ok { + for _, password := range passwordList { + ha1 := md5str(c.Username + ":" + Realm + ":" + password) + ha2 := md5str(method + ":" + uri) + resp := md5str(ha1 + ":" + c.Nonce + ":" + c.Nc + ":" + c.CNonce + ":auth:" + ha2) + if resp == c.Response { + return c.Username, true + } + } + } + return "", false +} + +func ParseChallenge(s string) (*Challenge, error) { + pp, err := param.Parse(s) + if err != nil { + return nil, fmt.Errorf("digest: invalid challenge: %w", err) + } + var c Challenge + + for _, p := range pp { + switch p.Key { + case "username": + c.Username = p.Value + case "nonce": + c.Nonce = p.Value + case "cnonce": + c.CNonce = p.Value + case "nc": + c.Nc = p.Value + case "response": + c.Response = p.Value + } + } + return &c, nil +} + +func md5str(str string) string { + h := md5.New() + h.Write([]byte(str)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/common/param/param.go b/common/param/param.go new file mode 100644 index 0000000..7fbd4d8 --- /dev/null +++ b/common/param/param.go @@ -0,0 +1,189 @@ +package param + +// code retrieve from https://github.com/icholy/digest/tree/master/internal/param + +import ( + "bufio" + "fmt" + "io" + "strconv" + "strings" +) + +// Param is a key/value header parameter +type Param struct { + Key string + Value string + Quote bool +} + +// String returns the formatted parameter +func (p Param) String() string { + if p.Quote { + return p.Key + "=" + strconv.Quote(p.Value) + } + return p.Key + "=" + p.Value +} + +// Format formats the parameters to be included in the header +func Format(pp ...Param) string { + var b strings.Builder + for i, p := range pp { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(p.String()) + } + return b.String() +} + +// Parse parses the header parameters +func Parse(s string) ([]Param, error) { + var pp []Param + br := bufio.NewReader(strings.NewReader(s)) + for i := 0; true; i++ { + // skip whitespace + if err := skipWhite(br); err != nil { + return nil, err + } + // see if there's more to read + if _, err := br.Peek(1); err == io.EOF { + break + } + // read key/value pair + p, err := parseParam(br, i == 0) + if err != nil { + return nil, fmt.Errorf("param: %w", err) + } + pp = append(pp, p) + } + return pp, nil +} + +func parseIdent(br *bufio.Reader) (string, error) { + var ident []byte + for { + b, err := br.ReadByte() + if err == io.EOF { + break + } + if err != nil { + return "", err + } + if !(('a' <= b && b <= 'z') || ('A' <= b && b <= 'Z') || '0' <= b && b <= '9' || b == '-') { + if err := br.UnreadByte(); err != nil { + return "", err + } + break + } + ident = append(ident, b) + } + return string(ident), nil +} + +func parseByte(br *bufio.Reader, expect byte) error { + b, err := br.ReadByte() + if err != nil { + if err == io.EOF { + return fmt.Errorf("expected '%c', got EOF", expect) + } + return err + } + if b != expect { + return fmt.Errorf("expected '%c', got '%c'", expect, b) + } + return nil +} + +func parseString(br *bufio.Reader) (string, error) { + var s []rune + // read the open quote + if err := parseByte(br, '"'); err != nil { + return "", err + } + // read the string + var escaped bool + for { + r, _, err := br.ReadRune() + if err != nil { + return "", err + } + if escaped { + s = append(s, r) + escaped = false + continue + } + if r == '\\' { + escaped = true + continue + } + // closing quote + if r == '"' { + break + } + s = append(s, r) + } + return string(s), nil +} + +func skipWhite(br *bufio.Reader) error { + for { + b, err := br.ReadByte() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + if b != ' ' { + return br.UnreadByte() + } + } +} + +func parseParam(br *bufio.Reader, first bool) (Param, error) { + // skip whitespace + if err := skipWhite(br); err != nil { + return Param{}, err + } + if !first { + // read the comma separator + if err := parseByte(br, ','); err != nil { + return Param{}, err + } + // skip whitespace + if err := skipWhite(br); err != nil { + return Param{}, err + } + } + // read the key + key, err := parseIdent(br) + if err != nil { + return Param{}, err + } + // skip whitespace + if err := skipWhite(br); err != nil { + return Param{}, err + } + // read the equals sign + if err := parseByte(br, '='); err != nil { + return Param{}, err + } + // skip whitespace + if err := skipWhite(br); err != nil { + return Param{}, err + } + // read the value + var value string + var quote bool + if b, _ := br.Peek(1); len(b) == 1 && b[0] == '"' { + quote = true + value, err = parseString(br) + } else { + value, err = parseIdent(br) + } + if err != nil { + return Param{}, err + } + return Param{Key: key, Value: value, Quote: quote}, nil +} diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index fd5817b..8624939 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -3,7 +3,9 @@ package http import ( std_bufio "bufio" "context" + "crypto/rand" "encoding/base64" + "encoding/hex" "io" "net" "net/http" @@ -42,6 +44,12 @@ func HandleConnectionEx( authOk bool ) authorization := request.Header.Get("Proxy-Authorization") + if strings.HasPrefix(authorization, "Digest ") { + username, authOk = authenticator.VerifyDigest(request.Method, request.RequestURI, authorization[7:]) + if authOk { + ctx = auth.ContextWithUser(ctx, username) + } + } if strings.HasPrefix(authorization, "Basic ") { userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:]) userPswdArr := strings.SplitN(string(userPassword), ":", 2) @@ -56,10 +64,31 @@ func HandleConnectionEx( } if !authOk { // Since no one else is using the library, use a fixed realm until rewritten - err = responseWith( - request, http.StatusProxyAuthRequired, - "Proxy-Authenticate", `Basic realm="sing-box" charset="UTF-8"`, - ).Write(conn) + // define realm in common/auth package, still "sing-box" now + nonce := ""; + randomBytes := make([]byte, 16) + _, err = rand.Read(randomBytes) + if err == nil { + nonce = hex.EncodeToString(randomBytes) + } + if nonce == "" { + err = responseWithBody( + request, http.StatusProxyAuthRequired, + "Proxy authentication required", + "Content-Type", "text/plain; charset=utf-8", + "Proxy-Authenticate", "Basic realm=\"" + auth.Realm + "\"", + "Connection", "close", + ).Write(conn) + } else { + err = responseWithBody( + request, http.StatusProxyAuthRequired, + "Proxy authentication required", + "Content-Type", "text/plain; charset=utf-8", + "Proxy-Authenticate", "Basic realm=\"" + auth.Realm + "\"", + "Proxy-Authenticate", "Digest realm=\"" + auth.Realm + "\", nonce=\"" + nonce + "\", qop=\"auth\", stale=false", + "Connection", "close", + ).Write(conn) + } if err != nil { return err } @@ -68,7 +97,8 @@ func HandleConnectionEx( } else if authorization != "" { return E.New("http: authentication failed, Proxy-Authorization=", authorization) } else { - return E.New("http: authentication failed, no Proxy-Authorization header") + //return E.New("http: authentication failed, no Proxy-Authorization header") + continue } } } @@ -270,3 +300,31 @@ func responseWith(request *http.Request, statusCode int, headers ...string) *htt Header: header, } } + +func responseWithBody(request *http.Request, statusCode int, body string, headers ...string) *http.Response { + var header http.Header + if len(headers) > 0 { + header = make(http.Header) + for i := 0; i < len(headers); i += 2 { + header.Add(headers[i], headers[i+1]) + } + } + var bodyReadCloser io.ReadCloser + var bodyContentLength = int64(0) + if body != "" { + bodyReadCloser = io.NopCloser(strings.NewReader(body)) + bodyContentLength = int64(len(body)) + } + return &http.Response{ + StatusCode: statusCode, + Status: http.StatusText(statusCode), + Proto: request.Proto, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + Header: header, + Body: bodyReadCloser, + ContentLength: bodyContentLength, + Close: true, + } +} +