diff --git a/connection.go b/connection.go index f234ea3f..1920bb2a 100644 --- a/connection.go +++ b/connection.go @@ -278,6 +278,7 @@ var newConnection = func( getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, clientAddressValidated, + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, @@ -385,7 +386,8 @@ var newClientConnection = func( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, - false, /* has no effect */ + false, // has no effect + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, diff --git a/internal/ackhandler/ackhandler.go b/internal/ackhandler/ackhandler.go index 2c7cc4fc..036ed5ea 100644 --- a/internal/ackhandler/ackhandler.go +++ b/internal/ackhandler/ackhandler.go @@ -14,10 +14,11 @@ func NewAckHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, tracer logging.ConnectionTracer, logger utils.Logger, ) (SentPacketHandler, ReceivedPacketHandler) { - sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) return sph, newReceivedPacketHandler(sph, rttStats, logger) } diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 0ed929fa..7fd071e6 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -62,8 +62,8 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro if ackEliciting { h.maybeQueueACK(pn, rcvTime, isMissing) } + //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE. switch ecn { - case protocol.ECNNon: case protocol.ECT0: h.ect0++ case protocol.ECT1: diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 3b972b66..da5d0718 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -92,6 +92,8 @@ type sentPacketHandler struct { // The alarm timeout alarm time.Time + enableECN bool + perspective protocol.Perspective tracer logging.ConnectionTracer @@ -110,6 +112,7 @@ func newSentPacketHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, tracer logging.ConnectionTracer, logger utils.Logger, @@ -130,6 +133,7 @@ func newSentPacketHandler( appDataPackets: newPacketNumberSpace(0, true), rttStats: rttStats, congestion: congestion, + enableECN: enableECN, perspective: pers, tracer: tracer, logger: logger, @@ -714,6 +718,9 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { } func (h *sentPacketHandler) ECNMode() protocol.ECN { + if !h.enableECN { + return protocol.ECNUnsupported + } // TODO: implement ECN logic return protocol.ECNNon } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 614b7c85..a1e1f0d9 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -44,7 +44,7 @@ var _ = Describe("SentPacketHandler", func() { JustBeforeEach(func() { lostPackets = nil rttStats := utils.NewRTTStats() - handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, perspective, nil, utils.DefaultLogger) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger) streamFrame = wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -984,7 +984,7 @@ var _ = Describe("SentPacketHandler", func() { Context("amplification limit, for the server, with validated address", func() { JustBeforeEach(func() { rttStats := utils.NewRTTStats() - handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, perspective, nil, utils.DefaultLogger) + handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, false, perspective, nil, utils.DefaultLogger) }) It("do not limits the window", func() { diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index b8104882..d056cb9d 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -37,14 +37,48 @@ func (t PacketType) String() string { type ECN uint8 const ( - ECNNon ECN = iota // 00 - ECT1 // 01 - ECT0 // 10 - ECNCE // 11 + ECNUnsupported ECN = iota + ECNNon // 00 + ECT1 // 01 + ECT0 // 10 + ECNCE // 11 ) +func ParseECNHeaderBits(bits byte) ECN { + switch bits { + case 0: + return ECNNon + case 0b00000010: + return ECT0 + case 0b00000001: + return ECT1 + case 0b00000011: + return ECNCE + default: + panic("invalid ECN bits") + } +} + +func (e ECN) ToHeaderBits() byte { + //nolint:exhaustive // There are only 4 values. + switch e { + case ECNNon: + return 0 + case ECT0: + return 0b00000010 + case ECT1: + return 0b00000001 + case ECNCE: + return 0b00000011 + default: + panic("ECN unsupported") + } +} + func (e ECN) String() string { switch e { + case ECNUnsupported: + return "ECN unsupported" case ECNNon: return "Not-ECT" case ECT1: diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index e672d31e..22359e6a 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -17,13 +17,22 @@ var _ = Describe("Protocol", func() { }) It("converts ECN bits from the IP header wire to the correct types", func() { - Expect(ECN(0)).To(Equal(ECNNon)) - Expect(ECN(0b00000010)).To(Equal(ECT0)) - Expect(ECN(0b00000001)).To(Equal(ECT1)) - Expect(ECN(0b00000011)).To(Equal(ECNCE)) + Expect(ParseECNHeaderBits(0)).To(Equal(ECNNon)) + Expect(ParseECNHeaderBits(0b00000010)).To(Equal(ECT0)) + Expect(ParseECNHeaderBits(0b00000001)).To(Equal(ECT1)) + Expect(ParseECNHeaderBits(0b00000011)).To(Equal(ECNCE)) + Expect(func() { ParseECNHeaderBits(0b1010101) }).To(Panic()) + }) + + It("converts to IP header bits", func() { + for _, v := range [...]ECN{ECNNon, ECT0, ECT1, ECNCE} { + Expect(ParseECNHeaderBits(v.ToHeaderBits())).To(Equal(v)) + } + Expect(func() { ECN(42).ToHeaderBits() }).To(Panic()) }) It("has a string representation for ECN", func() { + Expect(ECNUnsupported.String()).To(Equal("ECN unsupported")) Expect(ECNNon.String()).To(Equal("Not-ECT")) Expect(ECT0.String()).To(Equal("ECT(0)")) Expect(ECT1.String()).To(Equal("ECT(1)")) diff --git a/packet_handler_map.go b/packet_handler_map.go index d7d92377..ba30fb23 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -21,6 +21,8 @@ type connCapabilities struct { DF bool // GSO (Generic Segmentation Offload) supported GSO bool + // ECN (Explicit Congestion Notifications) supported + ECN bool } // rawConn is a connection that allow reading of a receivedPackeh. diff --git a/server.go b/server.go index 681e8956..d41c6465 100644 --- a/server.go +++ b/server.go @@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) return err } @@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) return err } @@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/sys_conn.go b/sys_conn.go index 4d7a6577..71cc4607 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -104,10 +104,13 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, _ protocol.ECN) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) { if gsoSize != 0 { panic("cannot use GSO with a basicConn") } + if ecn != protocol.ECNUnsupported { + panic("cannot use ECN with a basicConn") + } return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 7b6052ba..b9ddf3c1 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -141,6 +141,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { cap: connCapabilities{ DF: supportsDF, GSO: isGSOSupported(rawConn), + ECN: true, }, } for i := 0; i < batchSize; i++ { @@ -189,7 +190,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IP { switch hdr.Type { case msgTypeIPTOS: - p.ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case ipv4PKTINFO: ip, ifIndex, ok := parseIPv4PktInfo(body) if ok { @@ -206,7 +207,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IPV6 { switch hdr.Type { case unix.IPV6_TCLASS: - p.ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case unix.IPV6_PKTINFO: // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ @@ -237,11 +238,13 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso } oob = appendUDPSegmentSizeMsg(oob, gsoSize) } - if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { - if remoteUDPAddr.IP.To4() != nil { - oob = appendIPv4ECNMsg(oob, ecn) - } else { - oob = appendIPv6ECNMsg(oob, ecn) + if ecn != protocol.ECNUnsupported { + if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { + if remoteUDPAddr.IP.To4() != nil { + oob = appendIPv4ECNMsg(oob, ecn) + } else { + oob = appendIPv6ECNMsg(oob, ecn) + } } } n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) @@ -298,7 +301,7 @@ func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte { // UnixRights uses the private `data` method, but I *think* this achieves the same goal. offset := startLen + unix.CmsgSpace(0) - b[offset] = uint8(val) + b[offset] = val.ToHeaderBits() return b } @@ -313,6 +316,6 @@ func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte { // UnixRights uses the private `data` method, but I *think* this achieves the same goal. offset := startLen + unix.CmsgSpace(0) - b[offset] = uint8(val) + b[offset] = val.ToHeaderBits() return b } diff --git a/transport.go b/transport.go index b841222e..c021be77 100644 --- a/transport.go +++ b/transport.go @@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNNon) + return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNNon) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } }