mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-04-03 03:47:37 +03:00
Migrate to gobwas/ws
This commit is contained in:
parent
40a0b69918
commit
4d23773a25
13 changed files with 192 additions and 161 deletions
|
@ -50,7 +50,7 @@ func NewClientTransport(ctx context.Context, dialer N.Dialer, serverAddr M.Socks
|
|||
case C.V2RayTransportTypeGRPC:
|
||||
return NewGRPCClient(ctx, dialer, serverAddr, options.GRPCOptions, tlsConfig)
|
||||
case C.V2RayTransportTypeWebsocket:
|
||||
return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig), nil
|
||||
return v2raywebsocket.NewClient(ctx, dialer, serverAddr, options.WebsocketOptions, tlsConfig)
|
||||
case C.V2RayTransportTypeQUIC:
|
||||
if tlsConfig == nil {
|
||||
return nil, C.ErrTLSRequired
|
||||
|
|
|
@ -81,7 +81,7 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
|||
uri.Path = options.Path
|
||||
err := sHTTP.URLSetPath(&uri, options.Path)
|
||||
if err != nil {
|
||||
return nil, E.New("failed to set path: " + err.Error())
|
||||
return nil, E.Cause(err, "parse path")
|
||||
}
|
||||
client.url = &uri
|
||||
return client, nil
|
||||
|
|
|
@ -5,58 +5,37 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
sHTTP "github.com/sagernet/sing/protocol/http"
|
||||
"github.com/sagernet/websocket"
|
||||
"github.com/sagernet/ws"
|
||||
)
|
||||
|
||||
var _ adapter.V2RayClientTransport = (*Client)(nil)
|
||||
|
||||
type Client struct {
|
||||
dialer *websocket.Dialer
|
||||
dialer N.Dialer
|
||||
tlsConfig tls.Config
|
||||
serverAddr M.Socksaddr
|
||||
requestURL url.URL
|
||||
requestURLString string
|
||||
headers http.Header
|
||||
maxEarlyData uint32
|
||||
earlyDataHeaderName string
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) adapter.V2RayClientTransport {
|
||||
wsDialer := &websocket.Dialer{
|
||||
ReadBufferSize: 4 * 1024,
|
||||
WriteBufferSize: 4 * 1024,
|
||||
HandshakeTimeout: time.Second * 8,
|
||||
}
|
||||
func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, options option.V2RayWebsocketOptions, tlsConfig tls.Config) (adapter.V2RayClientTransport, error) {
|
||||
if tlsConfig != nil {
|
||||
if len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{"http/1.1"})
|
||||
}
|
||||
wsDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConn, err := tls.ClientHandshake(ctx, conn, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &deadConn{tlsConn}, nil
|
||||
}
|
||||
} else {
|
||||
wsDialer.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
conn, err := dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &deadConn{conn}, nil
|
||||
}
|
||||
}
|
||||
var requestURL url.URL
|
||||
if tlsConfig == nil {
|
||||
|
@ -68,37 +47,68 @@ func NewClient(ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr, opt
|
|||
requestURL.Path = options.Path
|
||||
err := sHTTP.URLSetPath(&requestURL, options.Path)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, E.Cause(err, "parse path")
|
||||
}
|
||||
if !strings.HasPrefix(requestURL.Path, "/") {
|
||||
requestURL.Path = "/" + requestURL.Path
|
||||
}
|
||||
headers := make(http.Header)
|
||||
for key, value := range options.Headers {
|
||||
headers[key] = value
|
||||
if key == "Host" {
|
||||
if len(value) > 1 {
|
||||
return nil, E.New("multiple Host headers")
|
||||
}
|
||||
requestURL.Host = value[0]
|
||||
}
|
||||
}
|
||||
if headers.Get("User-Agent") == "" {
|
||||
headers.Set("User-Agent", "Go-http-client/1.1")
|
||||
}
|
||||
return &Client{
|
||||
wsDialer,
|
||||
dialer,
|
||||
tlsConfig,
|
||||
serverAddr,
|
||||
requestURL,
|
||||
requestURL.String(),
|
||||
headers,
|
||||
options.MaxEarlyData,
|
||||
options.EarlyDataHeaderName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) dialContext(ctx context.Context, requestURL *url.URL, headers http.Header) (*WebsocketConn, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.tlsConfig != nil {
|
||||
conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(C.TCPTimeout))
|
||||
var protocols []string
|
||||
if protocolHeader := headers.Get("Sec-WebSocket-Protocol"); protocolHeader != "" {
|
||||
protocols = []string{protocolHeader}
|
||||
headers.Del("Sec-WebSocket-Protocol")
|
||||
}
|
||||
reader, _, err := ws.Dialer{Header: ws.HandshakeHeaderHTTP(headers), Protocols: protocols}.Upgrade(conn, requestURL)
|
||||
conn.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConn(conn, reader, nil, ws.StateClientSide), nil
|
||||
}
|
||||
|
||||
func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
if c.maxEarlyData <= 0 {
|
||||
conn, response, err := c.dialer.DialContext(ctx, c.requestURLString, c.headers)
|
||||
if err == nil {
|
||||
return &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}, nil
|
||||
conn, err := c.dialContext(ctx, &c.requestURL, c.headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, wrapDialError(response, err)
|
||||
return conn, nil
|
||||
} else {
|
||||
return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func wrapDialError(response *http.Response, err error) error {
|
||||
if response == nil {
|
||||
return err
|
||||
}
|
||||
return E.Extend(err, "HTTP ", response.StatusCode, " ", response.Status)
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
package v2raywebsocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -13,50 +13,96 @@ import (
|
|||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/websocket"
|
||||
"github.com/sagernet/ws"
|
||||
"github.com/sagernet/ws/wsutil"
|
||||
)
|
||||
|
||||
type WebsocketConn struct {
|
||||
*websocket.Conn
|
||||
net.Conn
|
||||
*Writer
|
||||
remoteAddr net.Addr
|
||||
reader io.Reader
|
||||
state ws.State
|
||||
reader *wsutil.Reader
|
||||
controlHandler wsutil.FrameHandlerFunc
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn {
|
||||
func NewConn(conn net.Conn, br *bufio.Reader, remoteAddr net.Addr, state ws.State) *WebsocketConn {
|
||||
controlHandler := wsutil.ControlFrameHandler(conn, state)
|
||||
var reader io.Reader
|
||||
if br != nil && br.Buffered() > 0 {
|
||||
reader = br
|
||||
} else {
|
||||
reader = conn
|
||||
}
|
||||
return &WebsocketConn{
|
||||
Conn: wsConn,
|
||||
remoteAddr: remoteAddr,
|
||||
Writer: NewWriter(wsConn, true),
|
||||
Conn: conn,
|
||||
state: state,
|
||||
reader: &wsutil.Reader{
|
||||
Source: reader,
|
||||
State: state,
|
||||
SkipHeaderCheck: !debug.Enabled,
|
||||
OnIntermediate: controlHandler,
|
||||
},
|
||||
controlHandler: controlHandler,
|
||||
remoteAddr: remoteAddr,
|
||||
Writer: NewWriter(conn, state),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WebsocketConn) Close() error {
|
||||
err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout))
|
||||
if err != nil {
|
||||
return c.Conn.Close()
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(C.TCPTimeout))
|
||||
frame := ws.NewCloseFrame(ws.NewCloseFrameBody(
|
||||
ws.StatusNormalClosure, "",
|
||||
))
|
||||
if c.state == ws.StateClientSide {
|
||||
frame = ws.MaskFrameInPlace(frame)
|
||||
}
|
||||
ws.WriteFrame(c.Conn, frame)
|
||||
c.Conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WebsocketConn) Read(b []byte) (n int, err error) {
|
||||
var header ws.Header
|
||||
for {
|
||||
if c.reader == nil {
|
||||
_, c.reader, err = c.NextReader()
|
||||
n, err = c.reader.Read(b)
|
||||
if n > 0 {
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
if !E.IsMulti(err, io.EOF, wsutil.ErrNoFrameAdvance) {
|
||||
return
|
||||
}
|
||||
header, err = c.reader.NextFrame()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if header.OpCode.IsControl() {
|
||||
err = c.controlHandler(header, c.reader)
|
||||
if err != nil {
|
||||
err = wrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
n, err = c.reader.Read(b)
|
||||
if E.IsMulti(err, io.EOF) {
|
||||
c.reader = nil
|
||||
continue
|
||||
}
|
||||
err = wrapError(err)
|
||||
if header.OpCode&ws.OpBinary == 0 {
|
||||
err = c.reader.Discard()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WebsocketConn) Write(p []byte) (n int, err error) {
|
||||
err = wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = len(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *WebsocketConn) RemoteAddr() net.Addr {
|
||||
|
@ -83,11 +129,7 @@ func (c *WebsocketConn) NeedAdditionalReadDeadline() bool {
|
|||
}
|
||||
|
||||
func (c *WebsocketConn) Upstream() any {
|
||||
return c.Conn.NetConn()
|
||||
}
|
||||
|
||||
func (c *WebsocketConn) UpstreamWriter() any {
|
||||
return c.Writer
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
type EarlyWebsocketConn struct {
|
||||
|
@ -113,8 +155,7 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
|
|||
var (
|
||||
earlyData []byte
|
||||
lateData []byte
|
||||
conn *websocket.Conn
|
||||
response *http.Response
|
||||
conn *WebsocketConn
|
||||
err error
|
||||
)
|
||||
if len(content) > int(c.maxEarlyData) {
|
||||
|
@ -128,23 +169,26 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
|
|||
if c.earlyDataHeaderName == "" {
|
||||
requestURL := c.requestURL
|
||||
requestURL.Path += earlyDataString
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
|
||||
conn, err = c.dialContext(c.ctx, &requestURL, c.headers)
|
||||
} else {
|
||||
headers := c.headers.Clone()
|
||||
headers.Set(c.earlyDataHeaderName, earlyDataString)
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
|
||||
conn, err = c.dialContext(c.ctx, &c.requestURL, headers)
|
||||
}
|
||||
} else {
|
||||
conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
|
||||
conn, err = c.dialContext(c.ctx, &c.requestURL, c.headers)
|
||||
}
|
||||
if err != nil {
|
||||
return wrapDialError(response, err)
|
||||
return err
|
||||
}
|
||||
c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
|
||||
if len(lateData) > 0 {
|
||||
_, err = c.conn.Write(lateData)
|
||||
_, err = conn.Write(lateData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
c.conn = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
|
||||
|
@ -230,13 +274,3 @@ func (c *EarlyWebsocketConn) Upstream() any {
|
|||
func (c *EarlyWebsocketConn) LazyHeadroom() bool {
|
||||
return c.conn == nil
|
||||
}
|
||||
|
||||
func wrapError(err error) error {
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
||||
return io.EOF
|
||||
}
|
||||
if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
package v2raywebsocket
|
||||
|
||||
import _ "unsafe"
|
||||
|
||||
//go:linkname maskBytes github.com/sagernet/websocket.maskBytes
|
||||
func maskBytes(key [4]byte, pos int, b []byte) int
|
|
@ -20,7 +20,7 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
sHttp "github.com/sagernet/sing/protocol/http"
|
||||
"github.com/sagernet/websocket"
|
||||
"github.com/sagernet/ws"
|
||||
)
|
||||
|
||||
var _ adapter.V2RayServerTransport = (*Server)(nil)
|
||||
|
@ -58,13 +58,6 @@ func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsCon
|
|||
return server, nil
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
HandshakeTimeout: C.TCPTimeout,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" {
|
||||
if request.URL.Path != s.path {
|
||||
|
@ -95,14 +88,14 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
|||
s.invalidRequest(writer, request, http.StatusBadRequest, E.Cause(err, "decode early data"))
|
||||
return
|
||||
}
|
||||
wsConn, err := upgrader.Upgrade(writer, request, nil)
|
||||
wsConn, reader, _, err := ws.UpgradeHTTP(request, writer)
|
||||
if err != nil {
|
||||
s.invalidRequest(writer, request, 0, E.Cause(err, "upgrade websocket connection"))
|
||||
return
|
||||
}
|
||||
var metadata M.Metadata
|
||||
metadata.Source = sHttp.SourceAddress(request)
|
||||
conn = NewServerConn(wsConn, metadata.Source.TCPAddr())
|
||||
conn = NewConn(wsConn, reader.Reader, metadata.Source.TCPAddr(), ws.StateServerSide)
|
||||
if len(earlyData) > 0 {
|
||||
conn = bufio.NewCachedConn(conn, buf.As(earlyData))
|
||||
}
|
||||
|
|
|
@ -2,36 +2,27 @@ package v2raywebsocket
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/websocket"
|
||||
"github.com/sagernet/ws"
|
||||
)
|
||||
|
||||
type Writer struct {
|
||||
*websocket.Conn
|
||||
writer N.ExtendedWriter
|
||||
isServer bool
|
||||
}
|
||||
|
||||
func NewWriter(conn *websocket.Conn, isServer bool) *Writer {
|
||||
func NewWriter(writer io.Writer, state ws.State) *Writer {
|
||||
return &Writer{
|
||||
conn,
|
||||
bufio.NewExtendedWriter(conn.NetConn()),
|
||||
isServer,
|
||||
bufio.NewExtendedWriter(writer),
|
||||
state == ws.StateServerSide,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
err = w.Conn.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *Writer) WriteBuffer(buffer *buf.Buffer) error {
|
||||
var payloadBitLength int
|
||||
dataLen := buffer.Len()
|
||||
|
@ -52,7 +43,7 @@ func (w *Writer) WriteBuffer(buffer *buf.Buffer) error {
|
|||
}
|
||||
|
||||
header := buffer.ExtendHeader(headerLen)
|
||||
header[0] = websocket.BinaryMessage | 1<<7
|
||||
header[0] = byte(ws.OpBinary) | 0x80
|
||||
if w.isServer {
|
||||
header[1] = 0
|
||||
} else {
|
||||
|
@ -72,16 +63,12 @@ func (w *Writer) WriteBuffer(buffer *buf.Buffer) error {
|
|||
if !w.isServer {
|
||||
maskKey := rand.Uint32()
|
||||
binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey)
|
||||
maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data)
|
||||
ws.Cipher(data, *(*[4]byte)(header[1+payloadBitLength:]), 0)
|
||||
}
|
||||
|
||||
return w.writer.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
func (w *Writer) Upstream() any {
|
||||
return w.Conn.NetConn()
|
||||
}
|
||||
|
||||
func (w *Writer) FrontHeadroom() int {
|
||||
return 14
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue