chore: code adjustments

This commit is contained in:
Toby 2023-07-24 17:08:19 -07:00
parent cbedb27f0f
commit f0ad2f77ca
3 changed files with 30 additions and 20 deletions

View file

@ -133,14 +133,19 @@ func (c *clientImpl) connect() error {
c.conn = conn c.conn = conn
if udpEnabled { if udpEnabled {
c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
go func() {
c.udpSM.Run()
// TODO: Mark connection as closed
}()
} }
return nil return nil
} }
// openStream wraps the stream with QStream, which handles Close() properly
func (c *clientImpl) openStream() (quic.Stream, error) {
stream, err := c.conn.OpenStream()
if err != nil {
return nil, err
}
return &utils.QStream{Stream: stream}, nil
}
func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
stream, err := c.openStream() stream, err := c.openStream()
if err != nil { if err != nil {
@ -272,12 +277,3 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
} }
return io.Conn.SendMessage(buf[:msgN]) return io.Conn.SendMessage(buf[:msgN])
} }
// openStream wraps the stream with QStream, which handles Close() properly
func (c *clientImpl) openStream() (quic.Stream, error) {
stream, err := c.conn.OpenStream()
if err != nil {
return nil, err
}
return &utils.QStream{Stream: stream}, nil
}

View file

@ -6,6 +6,7 @@ import (
"math/rand" "math/rand"
"sync" "sync"
coreErrs "github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/frag" "github.com/apernet/hysteria/core/internal/frag"
"github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/protocol"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
@ -86,21 +87,22 @@ type udpSessionManager struct {
mutex sync.Mutex mutex sync.Mutex
m map[uint32]*udpConn m map[uint32]*udpConn
nextID uint32 nextID uint32
closed bool
} }
func newUDPSessionManager(io udpIO) *udpSessionManager { func newUDPSessionManager(io udpIO) *udpSessionManager {
return &udpSessionManager{ m := &udpSessionManager{
io: io, io: io,
m: make(map[uint32]*udpConn), m: make(map[uint32]*udpConn),
nextID: 1, nextID: 1,
} }
go m.run()
return m
} }
// Run runs the session manager main loop. func (m *udpSessionManager) run() error {
// Exit and returns error when the underlying io returns error (e.g. closed). defer m.closeCleanup()
func (m *udpSessionManager) Run() error {
defer m.cleanup()
for { for {
msg, err := m.io.ReceiveMessage() msg, err := m.io.ReceiveMessage()
if err != nil { if err != nil {
@ -110,13 +112,14 @@ func (m *udpSessionManager) Run() error {
} }
} }
func (m *udpSessionManager) cleanup() { func (m *udpSessionManager) closeCleanup() {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
for _, conn := range m.m { for _, conn := range m.m {
m.close(conn) m.close(conn)
} }
m.closed = true
} }
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
@ -142,6 +145,10 @@ func (m *udpSessionManager) NewUDP() (HyUDPConn, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.closed {
return nil, coreErrs.ClosedError{}
}
id := m.nextID id := m.nextID
m.nextID++ m.nextID++

View file

@ -47,6 +47,13 @@ func (c DialError) Error() string {
return "dial error: " + c.Message return "dial error: " + c.Message
} }
// ClosedError is returned when the client attempts to use a closed connection.
type ClosedError struct{}
func (c ClosedError) Error() string {
return "connection closed"
}
// ProtocolError is returned when the server/client runs into an unexpected // ProtocolError is returned when the server/client runs into an unexpected
// or malformed request/response/message. // or malformed request/response/message.
type ProtocolError struct { type ProtocolError struct {