only add an ECN control message if ECN is supported

This commit is contained in:
Marten Seemann 2023-08-31 13:06:37 +07:00
parent 5dd6d91c11
commit b73a4de7ea
12 changed files with 90 additions and 29 deletions

View file

@ -278,6 +278,7 @@ var newConnection = func(
getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats,
clientAddressValidated,
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,
@ -385,7 +386,8 @@ var newClientConnection = func(
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats,
false, /* has no effect */
false, // has no effect
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,

View file

@ -14,10 +14,11 @@ func NewAckHandler(
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective,
tracer logging.ConnectionTracer,
logger utils.Logger,
) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger)
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, rttStats, logger)
}

View file

@ -62,8 +62,8 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro
if ackEliciting {
h.maybeQueueACK(pn, rcvTime, isMissing)
}
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE.
switch ecn {
case protocol.ECNNon:
case protocol.ECT0:
h.ect0++
case protocol.ECT1:

View file

@ -92,6 +92,8 @@ type sentPacketHandler struct {
// The alarm timeout
alarm time.Time
enableECN bool
perspective protocol.Perspective
tracer logging.ConnectionTracer
@ -110,6 +112,7 @@ func newSentPacketHandler(
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective,
tracer logging.ConnectionTracer,
logger utils.Logger,
@ -130,6 +133,7 @@ func newSentPacketHandler(
appDataPackets: newPacketNumberSpace(0, true),
rttStats: rttStats,
congestion: congestion,
enableECN: enableECN,
perspective: pers,
tracer: tracer,
logger: logger,
@ -714,6 +718,9 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
}
func (h *sentPacketHandler) ECNMode() protocol.ECN {
if !h.enableECN {
return protocol.ECNUnsupported
}
// TODO: implement ECN logic
return protocol.ECNNon
}

View file

@ -44,7 +44,7 @@ var _ = Describe("SentPacketHandler", func() {
JustBeforeEach(func() {
lostPackets = nil
rttStats := utils.NewRTTStats()
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, perspective, nil, utils.DefaultLogger)
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, false, false, perspective, nil, utils.DefaultLogger)
streamFrame = wire.StreamFrame{
StreamID: 5,
Data: []byte{0x13, 0x37},
@ -984,7 +984,7 @@ var _ = Describe("SentPacketHandler", func() {
Context("amplification limit, for the server, with validated address", func() {
JustBeforeEach(func() {
rttStats := utils.NewRTTStats()
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, perspective, nil, utils.DefaultLogger)
handler = newSentPacketHandler(42, protocol.InitialPacketSizeIPv4, rttStats, true, false, perspective, nil, utils.DefaultLogger)
})
It("do not limits the window", func() {

View file

@ -37,14 +37,48 @@ func (t PacketType) String() string {
type ECN uint8
const (
ECNNon ECN = iota // 00
ECT1 // 01
ECT0 // 10
ECNCE // 11
ECNUnsupported ECN = iota
ECNNon // 00
ECT1 // 01
ECT0 // 10
ECNCE // 11
)
func ParseECNHeaderBits(bits byte) ECN {
switch bits {
case 0:
return ECNNon
case 0b00000010:
return ECT0
case 0b00000001:
return ECT1
case 0b00000011:
return ECNCE
default:
panic("invalid ECN bits")
}
}
func (e ECN) ToHeaderBits() byte {
//nolint:exhaustive // There are only 4 values.
switch e {
case ECNNon:
return 0
case ECT0:
return 0b00000010
case ECT1:
return 0b00000001
case ECNCE:
return 0b00000011
default:
panic("ECN unsupported")
}
}
func (e ECN) String() string {
switch e {
case ECNUnsupported:
return "ECN unsupported"
case ECNNon:
return "Not-ECT"
case ECT1:

View file

@ -17,13 +17,22 @@ var _ = Describe("Protocol", func() {
})
It("converts ECN bits from the IP header wire to the correct types", func() {
Expect(ECN(0)).To(Equal(ECNNon))
Expect(ECN(0b00000010)).To(Equal(ECT0))
Expect(ECN(0b00000001)).To(Equal(ECT1))
Expect(ECN(0b00000011)).To(Equal(ECNCE))
Expect(ParseECNHeaderBits(0)).To(Equal(ECNNon))
Expect(ParseECNHeaderBits(0b00000010)).To(Equal(ECT0))
Expect(ParseECNHeaderBits(0b00000001)).To(Equal(ECT1))
Expect(ParseECNHeaderBits(0b00000011)).To(Equal(ECNCE))
Expect(func() { ParseECNHeaderBits(0b1010101) }).To(Panic())
})
It("converts to IP header bits", func() {
for _, v := range [...]ECN{ECNNon, ECT0, ECT1, ECNCE} {
Expect(ParseECNHeaderBits(v.ToHeaderBits())).To(Equal(v))
}
Expect(func() { ECN(42).ToHeaderBits() }).To(Panic())
})
It("has a string representation for ECN", func() {
Expect(ECNUnsupported.String()).To(Equal("ECN unsupported"))
Expect(ECNNon.String()).To(Equal("Not-ECT"))
Expect(ECT0.String()).To(Equal("ECT(0)"))
Expect(ECT1.String()).To(Equal("ECT(1)"))

View file

@ -21,6 +21,8 @@ type connCapabilities struct {
DF bool
// GSO (Generic Segmentation Offload) supported
GSO bool
// ECN (Explicit Congestion Notifications) supported
ECN bool
}
// rawConn is a connection that allow reading of a receivedPackeh.

View file

@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe
if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
}
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon)
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err
}
@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
}
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNNon)
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err
}
@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
if s.tracer != nil {
s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
}
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil {
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
}
}

View file

@ -104,10 +104,13 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) {
}, nil
}
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, _ protocol.ECN) (n int, err error) {
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) {
if gsoSize != 0 {
panic("cannot use GSO with a basicConn")
}
if ecn != protocol.ECNUnsupported {
panic("cannot use ECN with a basicConn")
}
return c.PacketConn.WriteTo(b, addr)
}

View file

@ -141,6 +141,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
cap: connCapabilities{
DF: supportsDF,
GSO: isGSOSupported(rawConn),
ECN: true,
},
}
for i := 0; i < batchSize; i++ {
@ -189,7 +190,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
if hdr.Level == unix.IPPROTO_IP {
switch hdr.Type {
case msgTypeIPTOS:
p.ecn = protocol.ECN(body[0] & ecnMask)
p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
case ipv4PKTINFO:
ip, ifIndex, ok := parseIPv4PktInfo(body)
if ok {
@ -206,7 +207,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
if hdr.Level == unix.IPPROTO_IPV6 {
switch hdr.Type {
case unix.IPV6_TCLASS:
p.ecn = protocol.ECN(body[0] & ecnMask)
p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
case unix.IPV6_PKTINFO:
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
@ -237,11 +238,13 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso
}
oob = appendUDPSegmentSizeMsg(oob, gsoSize)
}
if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok {
if remoteUDPAddr.IP.To4() != nil {
oob = appendIPv4ECNMsg(oob, ecn)
} else {
oob = appendIPv6ECNMsg(oob, ecn)
if ecn != protocol.ECNUnsupported {
if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok {
if remoteUDPAddr.IP.To4() != nil {
oob = appendIPv4ECNMsg(oob, ecn)
} else {
oob = appendIPv6ECNMsg(oob, ecn)
}
}
}
n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
@ -298,7 +301,7 @@ func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte {
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
b[offset] = uint8(val)
b[offset] = val.ToHeaderBits()
return b
}
@ -313,6 +316,6 @@ func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte {
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
b[offset] = uint8(val)
b[offset] = val.ToHeaderBits()
return b
}

View file

@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {
return 0, err
}
return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNNon)
return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
}
func (t *Transport) enqueueClosePacket(p closePacket) {
@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() {
case <-t.listening:
return
case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNNon)
t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported)
case p := <-t.statelessResetQueue:
t.sendStatelessReset(p)
}
@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) {
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNNon); err != nil {
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
}
}