diff --git a/benchmark_test.go b/benchmark_test.go index 6e28d640..64285969 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -50,7 +50,7 @@ func newLinkedConnection(other *session) *linkedConnection { return conn } -func (c *linkedConnection) write(p []byte) error { +func (c *linkedConnection) Write(p []byte) error { packet := getPacketBuffer() packet = packet[:len(p)] copy(packet, p) @@ -61,8 +61,10 @@ func (c *linkedConnection) write(p []byte) error { return nil } -func (*linkedConnection) setCurrentRemoteAddr(addr net.Addr) {} -func (*linkedConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} } +func (c *linkedConnection) Read(p []byte) (int, net.Addr, error) { panic("not implemented") } +func (*linkedConnection) SetCurrentRemoteAddr(addr net.Addr) {} +func (*linkedConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} } +func (c *linkedConnection) Close() error { return nil } func setAEAD(cs handshake.CryptoSetup, aead crypto.AEAD) { *(*bool)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true diff --git a/conn.go b/conn.go index 7f81af35..0fede0ba 100644 --- a/conn.go +++ b/conn.go @@ -6,9 +6,11 @@ import ( ) type connection interface { - write([]byte) error - setCurrentRemoteAddr(net.Addr) + Write([]byte) error + Read([]byte) (int, net.Addr, error) + Close() error RemoteAddr() net.Addr + SetCurrentRemoteAddr(net.Addr) } type conn struct { @@ -20,12 +22,16 @@ type conn struct { var _ connection = &conn{} -func (c *conn) write(p []byte) error { +func (c *conn) Write(p []byte) error { _, err := c.pconn.WriteTo(p, c.currentAddr) return err } -func (c *conn) setCurrentRemoteAddr(addr net.Addr) { +func (c *conn) Read(p []byte) (int, net.Addr, error) { + return c.pconn.ReadFrom(p) +} + +func (c *conn) SetCurrentRemoteAddr(addr net.Addr) { c.mutex.Lock() c.currentAddr = addr c.mutex.Unlock() @@ -37,3 +43,7 @@ func (c *conn) RemoteAddr() net.Addr { c.mutex.RUnlock() return addr } + +func (c *conn) Close() error { + return c.pconn.Close() +} diff --git a/conn_test.go b/conn_test.go index 9838ab81..fb6a4764 100644 --- a/conn_test.go +++ b/conn_test.go @@ -42,23 +42,36 @@ var _ net.PacketConn = &mockPacketConn{} var _ = Describe("Connection", func() { var c *conn + var packetConn *mockPacketConn BeforeEach(func() { addr := &net.UDPAddr{ IP: net.IPv4(192, 168, 100, 200), Port: 1337, } + packetConn = &mockPacketConn{} c = &conn{ currentAddr: addr, - pconn: &mockPacketConn{}, + pconn: packetConn, } }) It("writes", func() { - err := c.write([]byte("foobar")) + err := c.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - Expect(c.pconn.(*mockPacketConn).dataWritten.Bytes()).To(Equal([]byte("foobar"))) - Expect(c.pconn.(*mockPacketConn).dataWrittenTo.String()).To(Equal("192.168.100.200:1337")) + Expect(packetConn.dataWritten.Bytes()).To(Equal([]byte("foobar"))) + Expect(packetConn.dataWrittenTo.String()).To(Equal("192.168.100.200:1337")) + }) + + It("reads", func() { + packetConn.dataToRead = []byte("foo") + packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336} + p := make([]byte, 10) + n, raddr, err := c.Read(p) + Expect(err).ToNot(HaveOccurred()) + Expect(raddr.String()).To(Equal("127.0.0.1:1336")) + Expect(n).To(Equal(3)) + Expect(p[0:3]).To(Equal([]byte("foo"))) }) It("gets the remote address", func() { @@ -70,7 +83,13 @@ var _ = Describe("Connection", func() { IP: net.IPv4(127, 0, 0, 1), Port: 7331, } - c.setCurrentRemoteAddr(addr) + c.SetCurrentRemoteAddr(addr) Expect(c.RemoteAddr().String()).To(Equal(addr.String())) }) + + It("closes", func() { + err := c.Close() + Expect(err).ToNot(HaveOccurred()) + Expect(packetConn.closed).To(BeTrue()) + }) }) diff --git a/session.go b/session.go index 4c58025c..d9364707 100644 --- a/session.go +++ b/session.go @@ -332,7 +332,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } if s.perspective == protocol.PerspectiveServer { // update the remote address, even if unpacking failed for any other reason than a decryption error - s.conn.setCurrentRemoteAddr(p.remoteAddr) + s.conn.SetCurrentRemoteAddr(p.remoteAddr) } if err != nil { return err @@ -626,7 +626,7 @@ func (s *session) sendPacket() error { s.logPacket(packet) - err = s.conn.write(packet.raw) + err = s.conn.Write(packet.raw) putPacketBuffer(packet.raw) if err != nil { return err @@ -644,7 +644,7 @@ func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { return errors.New("Session BUG: expected packet not to be nil") } s.logPacket(packet) - return s.conn.write(packet.raw) + return s.conn.Write(packet.raw) } func (s *session) logPacket(packet *packedPacket) { @@ -723,7 +723,7 @@ func (s *session) garbageCollectStreams() { func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) - return s.conn.write(writePublicReset(s.connectionID, rejectedPacketNumber, 0)) + return s.conn.Write(writePublicReset(s.connectionID, rejectedPacketNumber, 0)) } // scheduleSending signals that we have data for sending diff --git a/session_test.go b/session_test.go index 49b55d5e..1c5eae61 100644 --- a/session_test.go +++ b/session_test.go @@ -29,17 +29,19 @@ type mockConnection struct { written [][]byte } -func (m *mockConnection) write(p []byte) error { +func (m *mockConnection) Write(p []byte) error { b := make([]byte, len(p)) copy(b, p) m.written = append(m.written, b) return nil } +func (m *mockConnection) Read([]byte) (int, net.Addr, error) { panic("not implemented") } -func (m *mockConnection) setCurrentRemoteAddr(addr net.Addr) { +func (m *mockConnection) SetCurrentRemoteAddr(addr net.Addr) { m.remoteAddr = addr } func (*mockConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} } +func (*mockConnection) Close() error { panic("not implemented") } type mockUnpacker struct { unpackErr error