mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 04:37:36 +03:00
return packet handling and connection errors in Dial
This commit is contained in:
parent
8bfeb2ea8d
commit
f53055b9a0
2 changed files with 60 additions and 30 deletions
32
client.go
32
client.go
|
@ -14,8 +14,9 @@ import (
|
|||
)
|
||||
|
||||
type client struct {
|
||||
mutex sync.Mutex
|
||||
connStateChangeCond sync.Cond
|
||||
mutex sync.Mutex
|
||||
connStateChangeOrErrCond sync.Cond
|
||||
listenErr error
|
||||
|
||||
conn connection
|
||||
hostname string
|
||||
|
@ -55,7 +56,7 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
|
|||
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
|
||||
}
|
||||
|
||||
c.connStateChangeCond.L = &c.mutex
|
||||
c.connStateChangeOrErrCond.L = &c.mutex
|
||||
|
||||
err = c.createNewSession(nil)
|
||||
if err != nil {
|
||||
|
@ -67,16 +68,20 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
|
|||
go c.listen()
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
for {
|
||||
if c.listenErr != nil {
|
||||
return nil, c.listenErr
|
||||
}
|
||||
if c.config.ConnState != nil && c.connState >= ConnStateVersionNegotiated {
|
||||
break
|
||||
}
|
||||
if c.config.ConnState == nil && c.connState == ConnStateForwardSecure {
|
||||
break
|
||||
}
|
||||
c.connStateChangeCond.Wait()
|
||||
c.connStateChangeOrErrCond.Wait()
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
|
||||
return c.session, nil
|
||||
}
|
||||
|
@ -98,16 +103,20 @@ func DialAddr(hostname string, config *Config) (Session, error) {
|
|||
|
||||
// Listen listens
|
||||
func (c *client) listen() {
|
||||
var err error
|
||||
|
||||
for {
|
||||
var n int
|
||||
var addr net.Addr
|
||||
data := getPacketBuffer()
|
||||
data = data[:protocol.MaxPacketSize]
|
||||
|
||||
n, addr, err := c.conn.Read(data)
|
||||
n, addr, err = c.conn.Read(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
c.session.Close(err)
|
||||
}
|
||||
return
|
||||
break
|
||||
}
|
||||
data = data[:n]
|
||||
|
||||
|
@ -115,9 +124,14 @@ func (c *client) listen() {
|
|||
if err != nil {
|
||||
utils.Errorf("error handling packet: %s", err.Error())
|
||||
c.session.Close(err)
|
||||
return
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.listenErr = err
|
||||
c.connStateChangeOrErrCond.Signal()
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||
|
@ -145,7 +159,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
|||
if !hdr.VersionFlag && c.connState == ConnStateInitial {
|
||||
c.mutex.Lock()
|
||||
c.connState = ConnStateVersionNegotiated
|
||||
c.connStateChangeCond.Signal()
|
||||
c.connStateChangeOrErrCond.Signal()
|
||||
c.mutex.Unlock()
|
||||
if c.config.ConnState != nil {
|
||||
go c.config.ConnState(c.session, ConnStateVersionNegotiated)
|
||||
|
|
|
@ -48,28 +48,44 @@ var _ = Describe("Client", func() {
|
|||
}
|
||||
})
|
||||
|
||||
It("creates a new client", func() {
|
||||
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
|
||||
var err error
|
||||
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
|
||||
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
|
||||
})
|
||||
|
||||
// TODO: actually test this
|
||||
// now we're only testing that Dial doesn't return directly after version negotiation
|
||||
It("only returns once a forward-secure connection is established if no ConnState is defined", func() {
|
||||
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
|
||||
config.ConnState = nil
|
||||
var dialReturned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
Context("Dialing", func() {
|
||||
It("creates a new client", func() {
|
||||
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
|
||||
var err error
|
||||
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
dialReturned = true
|
||||
}()
|
||||
Consistently(func() bool { return dialReturned }).Should(BeFalse())
|
||||
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
|
||||
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
|
||||
})
|
||||
|
||||
It("errors when receiving an invalid first packet from the server", func() {
|
||||
packetConn.dataToRead = []byte{0xff}
|
||||
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(sess).To(BeNil())
|
||||
})
|
||||
|
||||
It("errors when receiving an error from the connection", func() {
|
||||
testErr := errors.New("connection error")
|
||||
packetConn.readErr = testErr
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
// TODO: actually test this
|
||||
// now we're only testing that Dial doesn't return directly after version negotiation
|
||||
It("only returns once a forward-secure connection is established if no ConnState is defined", func() {
|
||||
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
|
||||
config.ConnState = nil
|
||||
var dialReturned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
dialReturned = true
|
||||
}()
|
||||
Consistently(func() bool { return dialReturned }).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
It("errors on invalid public header", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue