mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-05 04:47:40 +03:00
Compare commits
3 commits
Author | SHA1 | Date | |
---|---|---|---|
|
4216c14cf2 | ||
|
8c0bf1c05e | ||
|
ad36d3be6d |
4 changed files with 243 additions and 75 deletions
|
@ -120,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
|
|||
var readN int
|
||||
var from windows.Sockaddr
|
||||
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||
//goland:noinspection GoDirectComparisonOfErrors
|
||||
if w.readErr != nil {
|
||||
buffer.Release()
|
||||
return w.readErr != windows.WSAEWOULDBLOCK
|
||||
}
|
||||
if readN > 0 {
|
||||
buffer.Truncate(readN)
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
if w.readErr == windows.WSAEWOULDBLOCK {
|
||||
return false
|
||||
}
|
||||
w.options.PostReturn(buffer)
|
||||
w.buffer = buffer
|
||||
if from != nil {
|
||||
switch fromAddr := from.(type) {
|
||||
case *windows.SockaddrInet4:
|
||||
|
|
10
common/cache/lrucache.go
vendored
10
common/cache/lrucache.go
vendored
|
@ -261,9 +261,15 @@ func (c *LruCache[K, V]) Delete(key K) {
|
|||
func (c *LruCache[K, V]) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for element := c.lru.Front(); element != nil; element = element.Next() {
|
||||
c.deleteElement(element)
|
||||
|
||||
if c.onEvict != nil {
|
||||
for le := c.lru.Front(); le != nil; le = le.Next() {
|
||||
c.onEvict(le.Value.key, le.Value.value)
|
||||
}
|
||||
}
|
||||
|
||||
c.lru.Init()
|
||||
c.cache = make(map[K]*list.Element[*entry[K, V]])
|
||||
}
|
||||
|
||||
func (c *LruCache[K, V]) maybeDeleteOldest() {
|
||||
|
|
113
common/cache/lrucache_test.go
vendored
Normal file
113
common/cache/lrucache_test.go
vendored
Normal file
|
@ -0,0 +1,113 @@
|
|||
package cache_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/cache"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLRUCache(t *testing.T) {
|
||||
t.Run("basic operations", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := cache.New[string, int]()
|
||||
|
||||
c.Store("key1", 1)
|
||||
value, exists := c.Load("key1")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, 1, value)
|
||||
|
||||
value, exists = c.Load("missing")
|
||||
require.False(t, exists)
|
||||
require.Zero(t, value)
|
||||
|
||||
c.Delete("key1")
|
||||
_, exists = c.Load("key1")
|
||||
require.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("max size", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := cache.New[string, int](cache.WithSize[string, int](2))
|
||||
|
||||
c.Store("key1", 1)
|
||||
c.Store("key2", 2)
|
||||
c.Store("key3", 3)
|
||||
|
||||
_, exists := c.Load("key1")
|
||||
require.False(t, exists)
|
||||
|
||||
value, exists := c.Load("key2")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, 2, value)
|
||||
})
|
||||
|
||||
t.Run("expiration", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := cache.New[string, int](cache.WithAge[string, int](1))
|
||||
|
||||
c.Store("key1", 1)
|
||||
|
||||
value, exists := c.Load("key1")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, 1, value)
|
||||
|
||||
time.Sleep(time.Second * 2)
|
||||
|
||||
value, exists = c.Load("key1")
|
||||
require.False(t, exists)
|
||||
require.Zero(t, value)
|
||||
})
|
||||
|
||||
t.Run("clear", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
evicted := make(map[string]int)
|
||||
c := cache.New[string, int](
|
||||
cache.WithEvict[string, int](func(key string, value int) {
|
||||
evicted[key] = value
|
||||
}),
|
||||
)
|
||||
|
||||
c.Store("key1", 1)
|
||||
c.Store("key2", 2)
|
||||
|
||||
c.Clear()
|
||||
|
||||
require.Equal(t, map[string]int{"key1": 1, "key2": 2}, evicted)
|
||||
_, exists := c.Load("key1")
|
||||
require.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("load or store", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := cache.New[string, int]()
|
||||
|
||||
value, loaded := c.LoadOrStore("key1", func() int { return 1 })
|
||||
require.False(t, loaded)
|
||||
require.Equal(t, 1, value)
|
||||
|
||||
value, loaded = c.LoadOrStore("key1", func() int { return 2 })
|
||||
require.True(t, loaded)
|
||||
require.Equal(t, 1, value)
|
||||
})
|
||||
|
||||
t.Run("update age on get", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := cache.New[string, int](
|
||||
cache.WithAge[string, int](5),
|
||||
cache.WithUpdateAgeOnGet[string, int](),
|
||||
)
|
||||
|
||||
c.Store("key1", 1)
|
||||
|
||||
time.Sleep(time.Second * 3)
|
||||
_, exists := c.Load("key1")
|
||||
require.True(t, exists)
|
||||
|
||||
time.Sleep(time.Second * 3)
|
||||
_, exists = c.Load("key1")
|
||||
require.True(t, exists)
|
||||
})
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
std_bufio "bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
@ -28,7 +29,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
|
|||
if err != nil {
|
||||
return E.Cause(err, "read http request")
|
||||
}
|
||||
|
||||
if authenticator != nil {
|
||||
var (
|
||||
username string
|
||||
|
@ -72,11 +72,15 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
|
|||
}
|
||||
|
||||
if request.Method == "CONNECT" {
|
||||
portStr := request.URL.Port()
|
||||
if portStr == "" {
|
||||
portStr = "80"
|
||||
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
|
||||
if destination.Port == 0 {
|
||||
switch request.URL.Scheme {
|
||||
case "https", "wss":
|
||||
destination.Port = 443
|
||||
default:
|
||||
destination.Port = 80
|
||||
}
|
||||
}
|
||||
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr)
|
||||
_, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n")))
|
||||
if err != nil {
|
||||
return E.Cause(err, "write http response")
|
||||
|
@ -96,74 +100,119 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
|
|||
requestConn = conn
|
||||
}
|
||||
return handler.NewConnection(ctx, requestConn, metadata)
|
||||
}
|
||||
|
||||
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
|
||||
request.RequestURI = ""
|
||||
|
||||
removeHopByHopHeaders(request.Header)
|
||||
removeExtraHTTPHostPort(request)
|
||||
|
||||
if hostStr := request.Header.Get("Host"); hostStr != "" {
|
||||
if hostStr != request.URL.Host {
|
||||
request.Host = hostStr
|
||||
} else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" {
|
||||
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
|
||||
if destination.Port == 0 {
|
||||
switch request.URL.Scheme {
|
||||
case "https", "wss":
|
||||
destination.Port = 443
|
||||
default:
|
||||
destination.Port = 80
|
||||
}
|
||||
}
|
||||
metadata.Protocol = "http"
|
||||
metadata.Destination = destination
|
||||
serverConn, clientConn := pipe.Pipe()
|
||||
go func() {
|
||||
err := handler.NewConnection(ctx, clientConn, metadata)
|
||||
if err != nil {
|
||||
common.Close(serverConn, clientConn)
|
||||
}
|
||||
}()
|
||||
err = request.Write(serverConn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "http: write upgrade request")
|
||||
}
|
||||
if reader.Buffered() > 0 {
|
||||
_, err = io.CopyN(serverConn, reader, int64(reader.Buffered()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return bufio.CopyConn(ctx, conn, serverConn)
|
||||
} else {
|
||||
err = handleHTTPConnection(ctx, handler, conn, request, metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if request.URL.Scheme == "" || request.URL.Host == "" {
|
||||
return responseWith(request, http.StatusBadRequest).Write(conn)
|
||||
}
|
||||
func handleHTTPConnection(
|
||||
ctx context.Context,
|
||||
//nolint:staticcheck
|
||||
handler N.TCPConnectionHandler,
|
||||
conn net.Conn,
|
||||
request *http.Request,
|
||||
metadata M.Metadata,
|
||||
) error {
|
||||
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
|
||||
request.RequestURI = ""
|
||||
|
||||
var innerErr atomic.TypedValue[error]
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableCompression: true,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
metadata.Destination = M.ParseSocksaddr(address)
|
||||
metadata.Protocol = "http"
|
||||
input, output := pipe.Pipe()
|
||||
go func() {
|
||||
hErr := handler.NewConnection(ctx, output, metadata)
|
||||
if hErr != nil {
|
||||
innerErr.Store(hErr)
|
||||
common.Close(input, output)
|
||||
}
|
||||
}()
|
||||
return input, nil
|
||||
},
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
requestCtx, cancel := context.WithCancel(ctx)
|
||||
response, err := httpClient.Do(request.WithContext(requestCtx))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
|
||||
}
|
||||
removeHopByHopHeaders(request.Header)
|
||||
removeExtraHTTPHostPort(request)
|
||||
|
||||
removeHopByHopHeaders(response.Header)
|
||||
|
||||
if keepAlive {
|
||||
response.Header.Set("Proxy-Connection", "keep-alive")
|
||||
response.Header.Set("Connection", "keep-alive")
|
||||
response.Header.Set("Keep-Alive", "timeout=4")
|
||||
}
|
||||
|
||||
response.Close = !keepAlive
|
||||
|
||||
err = response.Write(conn)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return E.Errors(innerErr.Load(), err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
if !keepAlive {
|
||||
return conn.Close()
|
||||
if hostStr := request.Header.Get("Host"); hostStr != "" {
|
||||
if hostStr != request.URL.Host {
|
||||
request.Host = hostStr
|
||||
}
|
||||
}
|
||||
|
||||
if request.URL.Scheme == "" || request.URL.Host == "" {
|
||||
return responseWith(request, http.StatusBadRequest).Write(conn)
|
||||
}
|
||||
|
||||
var innerErr atomic.TypedValue[error]
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableCompression: true,
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
metadata.Destination = M.ParseSocksaddr(address)
|
||||
metadata.Protocol = "http"
|
||||
input, output := pipe.Pipe()
|
||||
go func() {
|
||||
hErr := handler.NewConnection(ctx, output, metadata)
|
||||
if hErr != nil {
|
||||
innerErr.Store(hErr)
|
||||
common.Close(input, output)
|
||||
}
|
||||
}()
|
||||
return input, nil
|
||||
},
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
requestCtx, cancel := context.WithCancel(ctx)
|
||||
response, err := httpClient.Do(request.WithContext(requestCtx))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(response.Header)
|
||||
|
||||
if keepAlive {
|
||||
response.Header.Set("Proxy-Connection", "keep-alive")
|
||||
response.Header.Set("Connection", "keep-alive")
|
||||
response.Header.Set("Keep-Alive", "timeout=4")
|
||||
}
|
||||
|
||||
response.Close = !keepAlive
|
||||
|
||||
err = response.Write(conn)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return E.Errors(innerErr.Load(), err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
if !keepAlive {
|
||||
return conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeHopByHopHeaders(header http.Header) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue