Compare commits

...

3 commits
dev ... v0.5.2

Author SHA1 Message Date
世界
4216c14cf2
Fix syscall packet read waiter for Windows 2025-02-28 12:14:09 +08:00
世界
8c0bf1c05e
Fix clear lru cache 2024-11-18 12:37:59 +08:00
世界
ad36d3be6d
http: Fix proxying websocket 2024-11-18 12:20:28 +08:00
4 changed files with 243 additions and 75 deletions

View file

@ -120,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
var readN int
var from windows.Sockaddr
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != windows.WSAEWOULDBLOCK
}
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
}
w.options.PostReturn(buffer)
w.buffer = buffer
if from != nil {
switch fromAddr := from.(type) {
case *windows.SockaddrInet4:

View file

@ -261,9 +261,15 @@ func (c *LruCache[K, V]) Delete(key K) {
func (c *LruCache[K, V]) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
for element := c.lru.Front(); element != nil; element = element.Next() {
c.deleteElement(element)
if c.onEvict != nil {
for le := c.lru.Front(); le != nil; le = le.Next() {
c.onEvict(le.Value.key, le.Value.value)
}
}
c.lru.Init()
c.cache = make(map[K]*list.Element[*entry[K, V]])
}
func (c *LruCache[K, V]) maybeDeleteOldest() {

113
common/cache/lrucache_test.go vendored Normal file
View file

@ -0,0 +1,113 @@
package cache_test
import (
"testing"
"time"
"github.com/sagernet/sing/common/cache"
"github.com/stretchr/testify/require"
)
func TestLRUCache(t *testing.T) {
t.Run("basic operations", func(t *testing.T) {
t.Parallel()
c := cache.New[string, int]()
c.Store("key1", 1)
value, exists := c.Load("key1")
require.True(t, exists)
require.Equal(t, 1, value)
value, exists = c.Load("missing")
require.False(t, exists)
require.Zero(t, value)
c.Delete("key1")
_, exists = c.Load("key1")
require.False(t, exists)
})
t.Run("max size", func(t *testing.T) {
t.Parallel()
c := cache.New[string, int](cache.WithSize[string, int](2))
c.Store("key1", 1)
c.Store("key2", 2)
c.Store("key3", 3)
_, exists := c.Load("key1")
require.False(t, exists)
value, exists := c.Load("key2")
require.True(t, exists)
require.Equal(t, 2, value)
})
t.Run("expiration", func(t *testing.T) {
t.Parallel()
c := cache.New[string, int](cache.WithAge[string, int](1))
c.Store("key1", 1)
value, exists := c.Load("key1")
require.True(t, exists)
require.Equal(t, 1, value)
time.Sleep(time.Second * 2)
value, exists = c.Load("key1")
require.False(t, exists)
require.Zero(t, value)
})
t.Run("clear", func(t *testing.T) {
t.Parallel()
evicted := make(map[string]int)
c := cache.New[string, int](
cache.WithEvict[string, int](func(key string, value int) {
evicted[key] = value
}),
)
c.Store("key1", 1)
c.Store("key2", 2)
c.Clear()
require.Equal(t, map[string]int{"key1": 1, "key2": 2}, evicted)
_, exists := c.Load("key1")
require.False(t, exists)
})
t.Run("load or store", func(t *testing.T) {
t.Parallel()
c := cache.New[string, int]()
value, loaded := c.LoadOrStore("key1", func() int { return 1 })
require.False(t, loaded)
require.Equal(t, 1, value)
value, loaded = c.LoadOrStore("key1", func() int { return 2 })
require.True(t, loaded)
require.Equal(t, 1, value)
})
t.Run("update age on get", func(t *testing.T) {
t.Parallel()
c := cache.New[string, int](
cache.WithAge[string, int](5),
cache.WithUpdateAgeOnGet[string, int](),
)
c.Store("key1", 1)
time.Sleep(time.Second * 3)
_, exists := c.Load("key1")
require.True(t, exists)
time.Sleep(time.Second * 3)
_, exists = c.Load("key1")
require.True(t, exists)
})
}

View file

@ -4,6 +4,7 @@ import (
std_bufio "bufio"
"context"
"encoding/base64"
"io"
"net"
"net/http"
"strings"
@ -28,7 +29,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
if err != nil {
return E.Cause(err, "read http request")
}
if authenticator != nil {
var (
username string
@ -72,11 +72,15 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
}
if request.Method == "CONNECT" {
portStr := request.URL.Port()
if portStr == "" {
portStr = "80"
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
if destination.Port == 0 {
switch request.URL.Scheme {
case "https", "wss":
destination.Port = 443
default:
destination.Port = 80
}
}
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr)
_, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n")))
if err != nil {
return E.Cause(err, "write http response")
@ -96,74 +100,119 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
requestConn = conn
}
return handler.NewConnection(ctx, requestConn, metadata)
}
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
request.RequestURI = ""
removeHopByHopHeaders(request.Header)
removeExtraHTTPHostPort(request)
if hostStr := request.Header.Get("Host"); hostStr != "" {
if hostStr != request.URL.Host {
request.Host = hostStr
} else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" {
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
if destination.Port == 0 {
switch request.URL.Scheme {
case "https", "wss":
destination.Port = 443
default:
destination.Port = 80
}
}
metadata.Protocol = "http"
metadata.Destination = destination
serverConn, clientConn := pipe.Pipe()
go func() {
err := handler.NewConnection(ctx, clientConn, metadata)
if err != nil {
common.Close(serverConn, clientConn)
}
}()
err = request.Write(serverConn)
if err != nil {
return E.Cause(err, "http: write upgrade request")
}
if reader.Buffered() > 0 {
_, err = io.CopyN(serverConn, reader, int64(reader.Buffered()))
if err != nil {
return err
}
}
return bufio.CopyConn(ctx, conn, serverConn)
} else {
err = handleHTTPConnection(ctx, handler, conn, request, metadata)
if err != nil {
return err
}
}
}
}
if request.URL.Scheme == "" || request.URL.Host == "" {
return responseWith(request, http.StatusBadRequest).Write(conn)
}
func handleHTTPConnection(
ctx context.Context,
//nolint:staticcheck
handler N.TCPConnectionHandler,
conn net.Conn,
request *http.Request,
metadata M.Metadata,
) error {
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
request.RequestURI = ""
var innerErr atomic.TypedValue[error]
httpClient := &http.Client{
Transport: &http.Transport{
DisableCompression: true,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
input, output := pipe.Pipe()
go func() {
hErr := handler.NewConnection(ctx, output, metadata)
if hErr != nil {
innerErr.Store(hErr)
common.Close(input, output)
}
}()
return input, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
requestCtx, cancel := context.WithCancel(ctx)
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
}
removeHopByHopHeaders(request.Header)
removeExtraHTTPHostPort(request)
removeHopByHopHeaders(response.Header)
if keepAlive {
response.Header.Set("Proxy-Connection", "keep-alive")
response.Header.Set("Connection", "keep-alive")
response.Header.Set("Keep-Alive", "timeout=4")
}
response.Close = !keepAlive
err = response.Write(conn)
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err)
}
cancel()
if !keepAlive {
return conn.Close()
if hostStr := request.Header.Get("Host"); hostStr != "" {
if hostStr != request.URL.Host {
request.Host = hostStr
}
}
if request.URL.Scheme == "" || request.URL.Host == "" {
return responseWith(request, http.StatusBadRequest).Write(conn)
}
var innerErr atomic.TypedValue[error]
httpClient := &http.Client{
Transport: &http.Transport{
DisableCompression: true,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
input, output := pipe.Pipe()
go func() {
hErr := handler.NewConnection(ctx, output, metadata)
if hErr != nil {
innerErr.Store(hErr)
common.Close(input, output)
}
}()
return input, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
requestCtx, cancel := context.WithCancel(ctx)
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
}
removeHopByHopHeaders(response.Header)
if keepAlive {
response.Header.Set("Proxy-Connection", "keep-alive")
response.Header.Set("Connection", "keep-alive")
response.Header.Set("Keep-Alive", "timeout=4")
}
response.Close = !keepAlive
err = response.Write(conn)
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err)
}
cancel()
if !keepAlive {
return conn.Close()
}
return nil
}
func removeHopByHopHeaders(header http.Header) {