diff --git a/.travis.yml b/.travis.yml index ef63eb87..e899d793 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,9 +14,7 @@ env: global: - TIMESCALE_FACTOR=20 matrix: - - TRAVIS_GOARCH=amd64 TESTMODE=unit - TRAVIS_GOARCH=amd64 TESTMODE=integration - - TRAVIS_GOARCH=386 TESTMODE=unit - TRAVIS_GOARCH=386 TESTMODE=integration # second part of the GOARCH workaround diff --git a/.travis/script.sh b/.travis/script.sh index f9501fbd..faea5a40 100755 --- a/.travis/script.sh +++ b/.travis/script.sh @@ -2,18 +2,12 @@ set -ex -if [ "${TESTMODE}" == "unit" ]; then - ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage integrationtests,benchmark -fi - -if [ "${TESTMODE}" == "integration" ]; then - # run benchmark tests - ginkgo -randomizeAllSpecs -randomizeSuites -trace benchmark -- -size=10 - # run benchmark tests with the Go race detector - # The Go race detector only works on amd64. - if [ "${TRAVIS_GOARCH}" == 'amd64' ]; then - ginkgo -race -randomizeAllSpecs -randomizeSuites -trace benchmark -- -size=5 - fi - # run integration tests - ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace integrationtests +# run benchmark tests +ginkgo -randomizeAllSpecs -randomizeSuites -trace benchmark -- -size=10 +# run benchmark tests with the Go race detector +# The Go race detector only works on amd64. +if [ "${TRAVIS_GOARCH}" != '386' ]; then + ginkgo -race -randomizeAllSpecs -randomizeSuites -trace benchmark -- -size=5 fi +# run integration tests +ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace integrationtests diff --git a/client_test.go b/client_test.go index 2d1b090a..6f8e8d08 100644 --- a/client_test.go +++ b/client_test.go @@ -575,7 +575,7 @@ var _ = Describe("Client", func() { _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) - Expect(cconn.(*conn).PacketConn).To(Equal(packetConn)) + Expect(cconn.(*sconn).PacketConn).To(Equal(packetConn)) Expect(version).To(Equal(config.Versions[0])) Expect(conf.Versions).To(Equal(config.Versions)) }) diff --git a/conn.go b/conn.go new file mode 100644 index 00000000..0f65c558 --- /dev/null +++ b/conn.go @@ -0,0 +1,50 @@ +package quic + +import ( + "io" + "net" + "time" + + "github.com/lucas-clemente/quic-go/internal/utils" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type connection interface { + ReadPacket() (*receivedPacket, error) + WriteTo([]byte, net.Addr) (int, error) + LocalAddr() net.Addr + io.Closer +} + +func wrapConn(pc net.PacketConn) (connection, error) { + udpConn, ok := pc.(*net.UDPConn) + if !ok { + utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") + return &basicConn{PacketConn: pc}, nil + } + return newConn(udpConn) +} + +type basicConn struct { + net.PacketConn +} + +var _ connection = &basicConn{} + +func (c *basicConn) ReadPacket() (*receivedPacket, error) { + buffer := getPacketBuffer() + // The packet size should not exceed protocol.MaxReceivePacketSize bytes + // If it does, we only read a truncated packet, which will then end up undecryptable + buffer.Data = buffer.Data[:protocol.MaxReceivePacketSize] + n, addr, err := c.PacketConn.ReadFrom(buffer.Data) + if err != nil { + return nil, err + } + return &receivedPacket{ + remoteAddr: addr, + rcvTime: time.Now(), + data: buffer.Data[:n], + buffer: buffer, + }, nil +} diff --git a/conn_ecn.go b/conn_ecn.go new file mode 100644 index 00000000..9662bc18 --- /dev/null +++ b/conn_ecn.go @@ -0,0 +1,91 @@ +// +build !windows + +package quic + +import ( + "errors" + "net" + "syscall" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +const ecnMask uint8 = 0x3 + +type ecnConn struct { + *net.UDPConn + oobBuffer []byte +} + +var _ connection = &ecnConn{} + +func newConn(c *net.UDPConn) (*ecnConn, error) { + rawConn, err := c.SyscallConn() + if err != nil { + return nil, err + } + // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection. + // Try enabling receiving of ECN for both IP versions. + // We expect at least one of those syscalls to succeed. + var errIPv4, errIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errIPv4 = setRECVTOS(fd) + }); err != nil { + return nil, err + } + if err := rawConn.Control(func(fd uintptr) { + errIPv6 = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVTCLASS, 1) + }); err != nil { + return nil, err + } + switch { + case errIPv4 == nil && errIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.") + case errIPv4 == nil && errIPv6 != nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.") + case errIPv4 != nil && errIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.") + case errIPv4 != nil && errIPv6 != nil: + return nil, errors.New("activating ECN failed for both IPv4 and IPv6") + } + return &ecnConn{ + UDPConn: c, + oobBuffer: make([]byte, 128), + }, nil +} + +func (c *ecnConn) ReadPacket() (*receivedPacket, error) { + buffer := getPacketBuffer() + // The packet size should not exceed protocol.MaxReceivePacketSize bytes + // If it does, we only read a truncated packet, which will then end up undecryptable + buffer.Data = buffer.Data[:protocol.MaxReceivePacketSize] + c.oobBuffer = c.oobBuffer[:cap(c.oobBuffer)] + n, oobn, _, addr, err := c.UDPConn.ReadMsgUDP(buffer.Data, c.oobBuffer) + if err != nil { + return nil, err + } + ctrlMsgs, err := syscall.ParseSocketControlMessage(c.oobBuffer[:oobn]) + if err != nil { + return nil, err + } + var ecn protocol.ECN + for _, ctrlMsg := range ctrlMsgs { + if ctrlMsg.Header.Level == syscall.IPPROTO_IP && ctrlMsg.Header.Type == msgTypeIPTOS { + ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) + break + } + if ctrlMsg.Header.Level == syscall.IPPROTO_IPV6 && ctrlMsg.Header.Type == syscall.IPV6_TCLASS { + ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) + break + } + } + return &receivedPacket{ + remoteAddr: addr, + rcvTime: time.Now(), + data: buffer.Data[:n], + ecn: ecn, + buffer: buffer, + }, nil +} diff --git a/conn_ecn_test.go b/conn_ecn_test.go new file mode 100644 index 00000000..7498e1ef --- /dev/null +++ b/conn_ecn_test.go @@ -0,0 +1,128 @@ +// +build !windows + +package quic + +import ( + "net" + "syscall" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Basic Conn Test", func() { + Context("ECN conn", func() { + runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) { + addr, err := net.ResolveUDPAddr(network, address) + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP(network, addr) + Expect(err).ToNot(HaveOccurred()) + ecnConn, err := newConn(udpConn) + Expect(err).ToNot(HaveOccurred()) + + packetChan := make(chan *receivedPacket) + go func() { + defer GinkgoRecover() + for { + p, err := ecnConn.ReadPacket() + if err != nil { + return + } + packetChan <- p + } + }() + + return udpConn, packetChan + } + + sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { + conn, err := net.DialUDP(network, nil, addr) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + rawConn, err := conn.SyscallConn() + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + ExpectWithOffset(1, rawConn.Control(func(fd uintptr) { + setECN(fd) + })).To(Succeed()) + _, err = conn.Write([]byte("foobar")) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + return conn.LocalAddr() + } + + It("reads ECN flags on IPv4", func() { + conn, packetChan := runServer("udp4", "localhost:0") + defer conn.Close() + + sentFrom := sendPacketWithECN( + "udp4", + conn.LocalAddr().(*net.UDPAddr), + func(fd uintptr) { + Expect(syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, 2)).To(Succeed()) + }, + ) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.remoteAddr).To(Equal(sentFrom)) + Expect(p.ecn).To(Equal(protocol.ECT0)) + }) + + It("reads ECN flags on IPv6", func() { + conn, packetChan := runServer("udp6", "[::]:0") + defer conn.Close() + + sentFrom := sendPacketWithECN( + "udp6", + conn.LocalAddr().(*net.UDPAddr), + func(fd uintptr) { + Expect(syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 3)).To(Succeed()) + }, + ) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.remoteAddr).To(Equal(sentFrom)) + Expect(p.ecn).To(Equal(protocol.ECNCE)) + }) + + It("reads ECN flags on a connection that supports both IPv4 and IPv6", func() { + conn, packetChan := runServer("udp", "0.0.0.0:0") + defer conn.Close() + port := conn.LocalAddr().(*net.UDPAddr).Port + + // IPv4 + sendPacketWithECN( + "udp4", + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, + func(fd uintptr) { + Expect(syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, 3)).To(Succeed()) + }, + ) + + var p *receivedPacket + Eventually(packetChan).Should(Receive(&p)) + Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) + Expect(p.ecn).To(Equal(protocol.ECNCE)) + + // IPv6 + sendPacketWithECN( + "udp6", + &net.UDPAddr{IP: net.IPv6loopback, Port: port}, + func(fd uintptr) { + Expect(syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 1)).To(Succeed()) + }, + ) + + Eventually(packetChan).Should(Receive(&p)) + Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse()) + Expect(p.ecn).To(Equal(protocol.ECT1)) + }) + }) +}) diff --git a/conn_helper_darwin.go b/conn_helper_darwin.go new file mode 100644 index 00000000..042221c1 --- /dev/null +++ b/conn_helper_darwin.go @@ -0,0 +1,14 @@ +// +build darwin + +package quic + +import "syscall" + +const ( + ip_recvtos = 27 + msgTypeIPTOS = ip_recvtos +) + +func setRECVTOS(fd uintptr) error { + return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ip_recvtos, 1) +} diff --git a/conn_helper_generic.go b/conn_helper_generic.go new file mode 100644 index 00000000..1935e7dd --- /dev/null +++ b/conn_helper_generic.go @@ -0,0 +1,11 @@ +// +build !darwin,!windows + +package quic + +import "syscall" + +const msgTypeIPTOS = syscall.IP_TOS + +func setRECVTOS(fd uintptr) error { + return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_RECVTOS, 1) +} diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 00000000..1043535e --- /dev/null +++ b/conn_test.go @@ -0,0 +1,26 @@ +package quic + +import ( + "net" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Basic Conn Test", func() { + It("reads a packet", func() { + c := newMockPacketConn() + addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} + c.dataReadFrom = addr + c.dataToRead <- []byte("foobar") + + conn, err := wrapConn(c) + Expect(err).ToNot(HaveOccurred()) + p, err := conn.ReadPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.data).To(Equal([]byte("foobar"))) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(100*time.Millisecond))) + Expect(p.remoteAddr).To(Equal(addr)) + }) +}) diff --git a/conn_windows.go b/conn_windows.go new file mode 100644 index 00000000..45acf096 --- /dev/null +++ b/conn_windows.go @@ -0,0 +1,9 @@ +// +build windows + +package quic + +import "net" + +func newConn(c *net.UDPConn) (connection, error) { + return &basicConn{PacketConn: c}, nil +} diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 014b371d..ef861c2b 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -37,10 +37,10 @@ func (t PacketType) String() string { type ECN uint8 const ( - ECNNon ECN = iota - ECT0 - ECT1 - ECNCE + ECNNon ECN = iota // 00 + ECT1 // 01 + ECT0 // 10 + ECNCE // 11 ) // A ByteCount in QUIC diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index a89f732f..117405e4 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -15,4 +15,11 @@ var _ = Describe("Protocol", func() { Expect(PacketType(10).String()).To(Equal("unknown packet type: 10")) }) }) + + It("converts ECN bits from the IP header wire to the correct types", func() { + Expect(ECN(0)).To(Equal(ECNNon)) + Expect(ECN(0b00000010)).To(Equal(ECT0)) + Expect(ECN(0b00000001)).To(Equal(ECT1)) + Expect(ECN(0b00000011)).To(Equal(ECNCE)) + }) }) diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go index bcdcdcf8..90784d8f 100644 --- a/mock_multiplexer_test.go +++ b/mock_multiplexer_test.go @@ -51,7 +51,7 @@ func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{ } // RemoveConn mocks base method -func (m *MockMultiplexer) RemoveConn(arg0 net.PacketConn) error { +func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoveConn", arg0) ret0, _ := ret[0].(error) diff --git a/multiplexer.go b/multiplexer.go index 51eadfe0..129db41f 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -15,9 +15,13 @@ var ( connMuxer multiplexer ) +type indexableConn interface { + LocalAddr() net.Addr +} + type multiplexer interface { AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) - RemoveConn(net.PacketConn) error + RemoveConn(indexableConn) error } type connManager struct { @@ -33,7 +37,7 @@ type connMultiplexer struct { mutex sync.Mutex conns map[string] /* LocalAddr().String() */ connManager - newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) packetHandlerManager // so it can be replaced in the tests + newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests logger utils.Logger } @@ -63,7 +67,10 @@ func (m *connMultiplexer) AddConn( connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() p, ok := m.conns[connIndex] if !ok { - manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) + manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) + if err != nil { + return nil, err + } p = connManager{ connIDLen: connIDLen, statelessResetKey: statelessResetKey, @@ -85,7 +92,7 @@ func (m *connMultiplexer) AddConn( return p.manager, nil } -func (m *connMultiplexer) RemoveConn(c net.PacketConn) error { +func (m *connMultiplexer) RemoveConn(c indexableConn) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/packet_handler_map.go b/packet_handler_map.go index 781f430e..1d18b5f2 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -31,7 +31,7 @@ func (e statelessResetErr) Error() string { type packetHandlerMap struct { mutex sync.RWMutex - conn net.PacketConn + conn connection connIDLen int handlers map[string] /* string(ConnectionID)*/ packetHandler @@ -54,12 +54,16 @@ type packetHandlerMap struct { var _ packetHandlerManager = &packetHandlerMap{} func newPacketHandlerMap( - conn net.PacketConn, + c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer, logger utils.Logger, -) packetHandlerManager { +) (packetHandlerManager, error) { + conn, err := wrapConn(c) + if err != nil { + return nil, err + } m := &packetHandlerMap{ conn: conn, connIDLen: connIDLen, @@ -78,7 +82,7 @@ func newPacketHandlerMap( go m.logUsage() } - return m + return m, nil } func (h *packetHandlerMap) logUsage() { @@ -253,55 +257,40 @@ func (h *packetHandlerMap) close(e error) error { func (h *packetHandlerMap) listen() { defer close(h.listening) for { - buffer := getPacketBuffer() - data := buffer.Data[:protocol.MaxReceivePacketSize] - // The packet size should not exceed protocol.MaxReceivePacketSize bytes - // If it does, we only read a truncated packet, which will then end up undecryptable - n, addr, err := h.conn.ReadFrom(data) + p, err := h.conn.ReadPacket() if err != nil { h.close(err) return } - h.handlePacket(addr, buffer, data[:n]) + h.handlePacket(p) } } -func (h *packetHandlerMap) handlePacket( - addr net.Addr, - buffer *packetBuffer, - data []byte, -) { - connID, err := wire.ParseConnectionID(data, h.connIDLen) +func (h *packetHandlerMap) handlePacket(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, h.connIDLen) if err != nil { - buffer.MaybeRelease() - h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err) + h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) if h.tracer != nil { - h.tracer.DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) + h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } + p.buffer.MaybeRelease() return } - rcvTime := time.Now() h.mutex.RLock() defer h.mutex.RUnlock() - if isStatelessReset := h.maybeHandleStatelessReset(data); isStatelessReset { + if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { return } handler, handlerFound := h.handlers[string(connID)] - p := &receivedPacket{ - remoteAddr: addr, - rcvTime: rcvTime, - buffer: buffer, - data: data, - } if handlerFound { // existing session handler.handlePacket(p) return } - if data[0]&0x80 == 0 { + if p.data[0]&0x80 == 0 { go h.maybeSendStatelessReset(p, connID) return } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 5e7e1a2b..13d1d312 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -56,7 +56,9 @@ var _ = Describe("Packet Handler Map", func() { JustBeforeEach(func() { conn = newMockPacketConn() - handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger).(*packetHandlerMap) + phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger) + Expect(err).ToNot(HaveOccurred()) + handler = phm.(*packetHandlerMap) }) AfterEach(func() { @@ -130,7 +132,11 @@ var _ = Describe("Packet Handler Map", func() { It("drops unparseable packets", func() { addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError) - handler.handlePacket(addr, getPacketBuffer(), []byte{0, 1, 2, 3}) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: []byte{0, 1, 2, 3}, + }) }) It("deletes removed sessions immediately", func() { @@ -138,7 +144,7 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) - handler.handlePacket(nil, nil, getPacket(connID)) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) // don't EXPECT any calls to handlePacket of the MockPacketHandler }) @@ -149,7 +155,7 @@ var _ = Describe("Packet Handler Map", func() { handler.Add(connID, sess) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - handler.handlePacket(nil, nil, getPacket(connID)) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) // don't EXPECT any calls to handlePacket of the MockPacketHandler }) @@ -163,13 +169,13 @@ var _ = Describe("Packet Handler Map", func() { }) handler.Add(connID, packetHandler) handler.Retire(connID) - handler.handlePacket(nil, nil, getPacket(connID)) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) Eventually(handled).Should(BeClosed()) }) It("drops packets for unknown receivers", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - handler.handlePacket(nil, nil, getPacket(connID)) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) }) It("closes the packet handlers when reading from the conn fails", func() { @@ -210,7 +216,7 @@ var _ = Describe("Packet Handler Map", func() { Expect(cid).To(Equal(connID)) }) handler.SetServer(server) - handler.handlePacket(nil, nil, p) + handler.handlePacket(&receivedPacket{data: p}) }) It("closes all server sessions", func() { @@ -232,7 +238,7 @@ var _ = Describe("Packet Handler Map", func() { // don't EXPECT any calls to server.handlePacket handler.SetServer(server) handler.CloseServer() - handler.handlePacket(nil, nil, p) + handler.handlePacket(&receivedPacket{data: p}) }) }) @@ -297,7 +303,7 @@ var _ = Describe("Packet Handler Map", func() { p = append(p, token[:]...) time.Sleep(scaleDuration(30 * time.Millisecond)) - handler.handlePacket(nil, nil, p) + handler.handlePacket(&receivedPacket{data: p}) }) It("ignores packets too small to contain a stateless reset", func() { @@ -332,7 +338,11 @@ var _ = Describe("Packet Handler Map", func() { It("sends stateless resets", func() { addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(addr, getPacketBuffer(), p) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) var reset mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&reset)) Expect(reset.to).To(Equal(addr)) @@ -343,7 +353,11 @@ var _ = Describe("Packet Handler Map", func() { It("doesn't send stateless resets for small packets", func() { addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) - handler.handlePacket(addr, getPacketBuffer(), p) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) Consistently(conn.dataWritten).ShouldNot(Receive()) }) }) @@ -352,7 +366,11 @@ var _ = Describe("Packet Handler Map", func() { It("doesn't send stateless resets", func() { addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(addr, getPacketBuffer(), p) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) Consistently(conn.dataWritten).ShouldNot(Receive()) }) }) diff --git a/send_conn.go b/send_conn.go index 1f775e29..3633f49f 100644 --- a/send_conn.go +++ b/send_conn.go @@ -12,23 +12,23 @@ type sendConn interface { RemoteAddr() net.Addr } -type conn struct { +type sconn struct { net.PacketConn remoteAddr net.Addr } -var _ sendConn = &conn{} +var _ sendConn = &sconn{} func newSendConn(c net.PacketConn, remote net.Addr) sendConn { - return &conn{PacketConn: c, remoteAddr: remote} + return &sconn{PacketConn: c, remoteAddr: remote} } -func (c *conn) Write(p []byte) error { +func (c *sconn) Write(p []byte) error { _, err := c.PacketConn.WriteTo(p, c.remoteAddr) return err } -func (c *conn) RemoteAddr() net.Addr { +func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr } diff --git a/session.go b/session.go index 1b5bcc74..d768d8e4 100644 --- a/session.go +++ b/session.go @@ -855,7 +855,7 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / return false } - if err := s.handleUnpackedPacket(packet, p.rcvTime, p.Size()); err != nil { + if err := s.handleUnpackedPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil { s.closeLocal(err) return false } @@ -973,6 +973,7 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { func (s *session) handleUnpackedPacket( packet *unpackedPacket, + ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount, // only for logging ) error { @@ -1070,7 +1071,7 @@ func (s *session) handleUnpackedPacket( } } - return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, protocol.ECNNon, packet.encryptionLevel, rcvTime, isAckEliciting) + return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) } func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { diff --git a/session_test.go b/session_test.go index 056f9fde..2e382991 100644 --- a/session_test.go +++ b/session_test.go @@ -756,6 +756,7 @@ var _ = Describe("Session", func() { PacketNumberLen: protocol.PacketNumberLen1, } packet := getPacket(hdr, nil) + packet.ecn = protocol.ECNCE rcvTime := time.Now().Add(-10 * time.Second) unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ packetNumber: 0x1337, @@ -766,7 +767,7 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) gomock.InOrder( rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial), - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNNon, protocol.EncryptionInitial, rcvTime, false), + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false), ) sess.receivedPacketHandler = rph packet.rcvTime = rcvTime @@ -785,6 +786,7 @@ var _ = Describe("Session", func() { buf := &bytes.Buffer{} Expect((&wire.PingFrame{}).Write(buf, sess.version)).To(Succeed()) packet := getPacket(hdr, nil) + packet.ecn = protocol.ECT1 unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ packetNumber: 0x1337, encryptionLevel: protocol.Encryption1RTT, @@ -794,7 +796,7 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) gomock.InOrder( rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNNon, protocol.Encryption1RTT, rcvTime, true), + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, true), ) sess.receivedPacketHandler = rph packet.rcvTime = rcvTime