mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
implement closing the quic client with an error
This commit is contained in:
parent
c42262c2b3
commit
08c267431b
5 changed files with 16 additions and 11 deletions
|
@ -110,8 +110,8 @@ func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the connection
|
// Close closes the connection
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close(e error) error {
|
||||||
_ = c.session.Close(nil)
|
_ = c.session.Close(e)
|
||||||
return c.conn.Close()
|
return c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,9 @@ package quic
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/protocol"
|
"github.com/lucas-clemente/quic-go/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
@ -59,6 +61,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("properly closes the client", func(done Done) {
|
It("properly closes the client", func(done Done) {
|
||||||
|
numGoRoutines := runtime.NumGoroutine()
|
||||||
startUDPConn()
|
startUDPConn()
|
||||||
var stoppedListening bool
|
var stoppedListening bool
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -67,11 +70,13 @@ var _ = Describe("Client", func() {
|
||||||
stoppedListening = true
|
stoppedListening = true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := client.Close()
|
testErr := errors.New("test error")
|
||||||
|
err := client.Close(testErr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Eventually(session.closed).Should(BeTrue())
|
Eventually(session.closed).Should(BeTrue())
|
||||||
Expect(session.closeReason).To(BeNil())
|
Expect(session.closeReason).To(MatchError(testErr))
|
||||||
Eventually(func() bool { return stoppedListening }).Should(BeTrue())
|
Eventually(func() bool { return stoppedListening }).Should(BeTrue())
|
||||||
|
Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines))
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -85,7 +90,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(client.session.(*Session).connectionID).To(Equal(client.connectionID))
|
Expect(client.session.(*Session).connectionID).To(Equal(client.connectionID))
|
||||||
Expect(client.session.(*Session).version).To(Equal(client.version))
|
Expect(client.session.(*Session).version).To(Equal(client.version))
|
||||||
|
|
||||||
err = client.Close()
|
err = client.Close(nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -130,7 +135,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(session.closed).To(BeFalse())
|
Expect(session.closed).To(BeFalse())
|
||||||
Eventually(func() bool { return stoppedListening }).Should(BeFalse())
|
Eventually(func() bool { return stoppedListening }).Should(BeFalse())
|
||||||
|
|
||||||
err = client.Close()
|
err = client.Close(nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
|
@ -202,7 +207,7 @@ var _ = Describe("Client", func() {
|
||||||
// it didn't pass the version negoation packet to the session (since it has no payload)
|
// it didn't pass the version negoation packet to the session (since it has no payload)
|
||||||
Expect(session.packetCount).To(BeZero())
|
Expect(session.packetCount).To(BeZero())
|
||||||
|
|
||||||
err = client.Close()
|
err = client.Close(nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.Listen()
|
err = client.Listen()
|
||||||
defer client.Close()
|
defer client.Close(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
|
|
||||||
type quicClient interface {
|
type quicClient interface {
|
||||||
OpenStream(protocol.StreamID) (utils.Stream, error)
|
OpenStream(protocol.StreamID) (utils.Stream, error)
|
||||||
Close() error
|
Close(error) error
|
||||||
Listen() error
|
Listen() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,8 +14,8 @@ type mockQuicClient struct {
|
||||||
streams map[protocol.StreamID]*mockStream
|
streams map[protocol.StreamID]*mockStream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockQuicClient) Close() error { panic("not implemented") }
|
func (m *mockQuicClient) Close(error) error { panic("not implemented") }
|
||||||
func (m *mockQuicClient) Listen() error { panic("not implemented") }
|
func (m *mockQuicClient) Listen() error { panic("not implemented") }
|
||||||
func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
||||||
_, ok := m.streams[id]
|
_, ok := m.streams[id]
|
||||||
if ok {
|
if ok {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue