From 1793988a6db428084a123749bef8e9e55c371731 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 22 Oct 2024 15:33:49 +0800 Subject: [PATCH] Migrate clashtcpip to gVisor tcpip copied --- internal/clashtcpip/icmp.go | 40 - internal/clashtcpip/icmpv6.go | 172 --- internal/clashtcpip/ip.go | 209 --- internal/clashtcpip/ipv6.go | 141 -- internal/clashtcpip/tcp.go | 94 -- internal/clashtcpip/tcpip.go | 24 - internal/clashtcpip/tcpip_amd64.go | 26 - internal/clashtcpip/tcpip_amd64.s | 140 -- internal/clashtcpip/tcpip_amd64_test.go | 51 - internal/clashtcpip/tcpip_arm64.go | 24 - internal/clashtcpip/tcpip_arm64.s | 118 -- internal/clashtcpip/tcpip_arm64_test.go | 51 - internal/clashtcpip/tcpip_compat.go | 14 - internal/clashtcpip/tcpip_compat_test.go | 26 - internal/clashtcpip/udp.go | 59 - internal/gtcpip/README.md | 4 + internal/gtcpip/checksum/checksum.go | 68 + internal/gtcpip/checksum/checksum_unsafe.go | 182 +++ internal/gtcpip/errors.go | 46 + internal/gtcpip/header/checksum.go | 107 ++ internal/gtcpip/header/eth.go | 192 +++ internal/gtcpip/header/icmpv4.go | 228 ++++ internal/gtcpip/header/icmpv6.go | 304 +++++ internal/gtcpip/header/ipv4.go | 1205 +++++++++++++++++ internal/gtcpip/header/ipv6.go | 578 ++++++++ .../gtcpip/header/ipv6_extension_headers.go | 955 +++++++++++++ internal/gtcpip/header/ipv6_fragment.go | 158 +++ internal/gtcpip/header/ndp_neighbor_advert.go | 110 ++ .../gtcpip/header/ndp_neighbor_solicit.go | 52 + internal/gtcpip/header/ndp_options.go | 1073 +++++++++++++++ internal/gtcpip/header/ndp_router_advert.go | 204 +++ internal/gtcpip/header/ndp_router_solicit.go | 36 + .../header/ndpoptionidentifier_string.go | 58 + internal/gtcpip/header/netip.go | 35 + internal/gtcpip/header/tcp.go | 727 ++++++++++ internal/gtcpip/header/udp.go | 195 +++ internal/gtcpip/seqnum/seqnum.go | 62 + internal/gtcpip/tcpip.go | 573 ++++++++ network_name.go | 21 +- stack_gvisor.go | 10 +- stack_mixed.go | 51 +- stack_system.go | 283 ++-- tun_linux_offload.go | 4 +- 43 files changed, 7370 insertions(+), 1340 deletions(-) delete mode 100644 internal/clashtcpip/icmp.go delete mode 100644 internal/clashtcpip/icmpv6.go delete mode 100644 internal/clashtcpip/ip.go delete mode 100644 internal/clashtcpip/ipv6.go delete mode 100644 internal/clashtcpip/tcp.go delete mode 100644 internal/clashtcpip/tcpip.go delete mode 100644 internal/clashtcpip/tcpip_amd64.go delete mode 100644 internal/clashtcpip/tcpip_amd64.s delete mode 100644 internal/clashtcpip/tcpip_amd64_test.go delete mode 100644 internal/clashtcpip/tcpip_arm64.go delete mode 100644 internal/clashtcpip/tcpip_arm64.s delete mode 100644 internal/clashtcpip/tcpip_arm64_test.go delete mode 100644 internal/clashtcpip/tcpip_compat.go delete mode 100644 internal/clashtcpip/tcpip_compat_test.go delete mode 100644 internal/clashtcpip/udp.go create mode 100644 internal/gtcpip/README.md create mode 100644 internal/gtcpip/checksum/checksum.go create mode 100644 internal/gtcpip/checksum/checksum_unsafe.go create mode 100644 internal/gtcpip/errors.go create mode 100644 internal/gtcpip/header/checksum.go create mode 100644 internal/gtcpip/header/eth.go create mode 100644 internal/gtcpip/header/icmpv4.go create mode 100644 internal/gtcpip/header/icmpv6.go create mode 100644 internal/gtcpip/header/ipv4.go create mode 100644 internal/gtcpip/header/ipv6.go create mode 100644 internal/gtcpip/header/ipv6_extension_headers.go create mode 100644 internal/gtcpip/header/ipv6_fragment.go create mode 100644 internal/gtcpip/header/ndp_neighbor_advert.go create mode 100644 internal/gtcpip/header/ndp_neighbor_solicit.go create mode 100644 internal/gtcpip/header/ndp_options.go create mode 100644 internal/gtcpip/header/ndp_router_advert.go create mode 100644 internal/gtcpip/header/ndp_router_solicit.go create mode 100644 internal/gtcpip/header/ndpoptionidentifier_string.go create mode 100644 internal/gtcpip/header/netip.go create mode 100644 internal/gtcpip/header/tcp.go create mode 100644 internal/gtcpip/header/udp.go create mode 100644 internal/gtcpip/seqnum/seqnum.go create mode 100644 internal/gtcpip/tcpip.go diff --git a/internal/clashtcpip/icmp.go b/internal/clashtcpip/icmp.go deleted file mode 100644 index 0050fd5..0000000 --- a/internal/clashtcpip/icmp.go +++ /dev/null @@ -1,40 +0,0 @@ -package clashtcpip - -import ( - "encoding/binary" -) - -type ICMPType = byte - -const ( - ICMPTypePingRequest byte = 0x8 - ICMPTypePingResponse byte = 0x0 -) - -type ICMPPacket []byte - -func (p ICMPPacket) Type() ICMPType { - return p[0] -} - -func (p ICMPPacket) SetType(v ICMPType) { - p[0] = v -} - -func (p ICMPPacket) Code() byte { - return p[1] -} - -func (p ICMPPacket) Checksum() uint16 { - return binary.BigEndian.Uint16(p[2:]) -} - -func (p ICMPPacket) SetChecksum(sum [2]byte) { - p[2] = sum[0] - p[3] = sum[1] -} - -func (p ICMPPacket) ResetChecksum() { - p.SetChecksum(zeroChecksum) - p.SetChecksum(Checksum(0, p)) -} diff --git a/internal/clashtcpip/icmpv6.go b/internal/clashtcpip/icmpv6.go deleted file mode 100644 index 67f92d1..0000000 --- a/internal/clashtcpip/icmpv6.go +++ /dev/null @@ -1,172 +0,0 @@ -package clashtcpip - -import ( - "encoding/binary" -) - -type ICMPv6Packet []byte - -const ( - ICMPv6HeaderSize = 4 - - ICMPv6MinimumSize = 8 - - ICMPv6PayloadOffset = 8 - - ICMPv6EchoMinimumSize = 8 - - ICMPv6ErrorHeaderSize = 8 - - ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize - - ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize - - ICMPv6ChecksumOffset = 2 - - icmpv6PointerOffset = 4 - - icmpv6MTUOffset = 4 - - icmpv6IdentOffset = 4 - - icmpv6SequenceOffset = 6 - - NDPHopLimit = 255 -) - -type ICMPv6Type byte - -const ( - ICMPv6DstUnreachable ICMPv6Type = 1 - ICMPv6PacketTooBig ICMPv6Type = 2 - ICMPv6TimeExceeded ICMPv6Type = 3 - ICMPv6ParamProblem ICMPv6Type = 4 - ICMPv6EchoRequest ICMPv6Type = 128 - ICMPv6EchoReply ICMPv6Type = 129 - - ICMPv6RouterSolicit ICMPv6Type = 133 - ICMPv6RouterAdvert ICMPv6Type = 134 - ICMPv6NeighborSolicit ICMPv6Type = 135 - ICMPv6NeighborAdvert ICMPv6Type = 136 - ICMPv6RedirectMsg ICMPv6Type = 137 - - ICMPv6MulticastListenerQuery ICMPv6Type = 130 - ICMPv6MulticastListenerReport ICMPv6Type = 131 - ICMPv6MulticastListenerDone ICMPv6Type = 132 -) - -func (typ ICMPv6Type) IsErrorType() bool { - return typ&0x80 == 0 -} - -type ICMPv6Code byte - -const ( - ICMPv6NetworkUnreachable ICMPv6Code = 0 - ICMPv6Prohibited ICMPv6Code = 1 - ICMPv6BeyondScope ICMPv6Code = 2 - ICMPv6AddressUnreachable ICMPv6Code = 3 - ICMPv6PortUnreachable ICMPv6Code = 4 - ICMPv6Policy ICMPv6Code = 5 - ICMPv6RejectRoute ICMPv6Code = 6 -) - -const ( - ICMPv6HopLimitExceeded ICMPv6Code = 0 - ICMPv6ReassemblyTimeout ICMPv6Code = 1 -) - -const ( - ICMPv6ErroneousHeader ICMPv6Code = 0 - - ICMPv6UnknownHeader ICMPv6Code = 1 - - ICMPv6UnknownOption ICMPv6Code = 2 -) - -const ICMPv6UnusedCode ICMPv6Code = 0 - -func (b ICMPv6Packet) Type() ICMPv6Type { - return ICMPv6Type(b[0]) -} - -func (b ICMPv6Packet) SetType(t ICMPv6Type) { - b[0] = byte(t) -} - -func (b ICMPv6Packet) Code() ICMPv6Code { - return ICMPv6Code(b[1]) -} - -func (b ICMPv6Packet) SetCode(c ICMPv6Code) { - b[1] = byte(c) -} - -func (b ICMPv6Packet) TypeSpecific() uint32 { - return binary.BigEndian.Uint32(b[icmpv6PointerOffset:]) -} - -func (b ICMPv6Packet) SetTypeSpecific(val uint32) { - binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) -} - -func (b ICMPv6Packet) Checksum() uint16 { - return binary.BigEndian.Uint16(b[ICMPv6ChecksumOffset:]) -} - -func (b ICMPv6Packet) SetChecksum(sum [2]byte) { - _ = b[ICMPv6ChecksumOffset+1] - b[ICMPv6ChecksumOffset] = sum[0] - b[ICMPv6ChecksumOffset+1] = sum[1] -} - -func (ICMPv6Packet) SourcePort() uint16 { - return 0 -} - -func (ICMPv6Packet) DestinationPort() uint16 { - return 0 -} - -func (ICMPv6Packet) SetSourcePort(uint16) { -} - -func (ICMPv6Packet) SetDestinationPort(uint16) { -} - -func (b ICMPv6Packet) MTU() uint32 { - return binary.BigEndian.Uint32(b[icmpv6MTUOffset:]) -} - -func (b ICMPv6Packet) SetMTU(mtu uint32) { - binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu) -} - -func (b ICMPv6Packet) Ident() uint16 { - return binary.BigEndian.Uint16(b[icmpv6IdentOffset:]) -} - -func (b ICMPv6Packet) SetIdent(ident uint16) { - binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident) -} - -func (b ICMPv6Packet) Sequence() uint16 { - return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:]) -} - -func (b ICMPv6Packet) SetSequence(sequence uint16) { - binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) -} - -func (b ICMPv6Packet) MessageBody() []byte { - return b[ICMPv6HeaderSize:] -} - -func (b ICMPv6Packet) Payload() []byte { - return b[ICMPv6PayloadOffset:] -} - -func (b ICMPv6Packet) ResetChecksum(psum uint32) { - b.SetChecksum(zeroChecksum) - b.SetChecksum(Checksum(psum, b)) -} diff --git a/internal/clashtcpip/ip.go b/internal/clashtcpip/ip.go deleted file mode 100644 index 598656f..0000000 --- a/internal/clashtcpip/ip.go +++ /dev/null @@ -1,209 +0,0 @@ -package clashtcpip - -import ( - "encoding/binary" - "errors" - "net/netip" -) - -type IPProtocol = byte - -type IP interface { - Payload() []byte - SourceIP() netip.Addr - DestinationIP() netip.Addr - SetSourceIP(ip netip.Addr) - SetDestinationIP(ip netip.Addr) - Protocol() IPProtocol - DecTimeToLive() - ResetChecksum() - PseudoSum() uint32 -} - -// IPProtocol type -const ( - ICMP IPProtocol = 0x01 - TCP IPProtocol = 0x06 - UDP IPProtocol = 0x11 - ICMPv6 IPProtocol = 0x3a -) - -const ( - FlagDontFragment = 1 << 1 - FlagMoreFragment = 1 << 2 -) - -const ( - IPv4HeaderSize = 20 - - IPv4Version = 4 - - IPv4OptionsOffset = 20 - IPv4PacketMinLength = IPv4OptionsOffset -) - -var ( - ErrInvalidLength = errors.New("invalid packet length") - ErrInvalidIPVersion = errors.New("invalid ip version") - ErrInvalidChecksum = errors.New("invalid checksum") -) - -type IPv4Packet []byte - -func (p IPv4Packet) TotalLen() uint16 { - return binary.BigEndian.Uint16(p[2:]) -} - -func (p IPv4Packet) SetTotalLength(length uint16) { - binary.BigEndian.PutUint16(p[2:], length) -} - -func (p IPv4Packet) HeaderLen() uint16 { - return uint16(p[0]&0xf) * 4 -} - -func (p IPv4Packet) SetHeaderLen(length uint16) { - p[0] &= 0xF0 - p[0] |= byte(length / 4) -} - -func (p IPv4Packet) TypeOfService() byte { - return p[1] -} - -func (p IPv4Packet) SetTypeOfService(tos byte) { - p[1] = tos -} - -func (p IPv4Packet) Identification() uint16 { - return binary.BigEndian.Uint16(p[4:]) -} - -func (p IPv4Packet) SetIdentification(id uint16) { - binary.BigEndian.PutUint16(p[4:], id) -} - -func (p IPv4Packet) FragmentOffset() uint16 { - return binary.BigEndian.Uint16([]byte{p[6] & 0x7, p[7]}) * 8 -} - -func (p IPv4Packet) SetFragmentOffset(offset uint32) { - flags := p.Flags() - binary.BigEndian.PutUint16(p[6:], uint16(offset/8)) - p.SetFlags(flags) -} - -func (p IPv4Packet) DataLen() uint16 { - return p.TotalLen() - p.HeaderLen() -} - -func (p IPv4Packet) Payload() []byte { - return p[p.HeaderLen():p.TotalLen()] -} - -func (p IPv4Packet) Protocol() IPProtocol { - return p[9] -} - -func (p IPv4Packet) SetProtocol(protocol IPProtocol) { - p[9] = protocol -} - -func (p IPv4Packet) Flags() byte { - return p[6] >> 5 -} - -func (p IPv4Packet) SetFlags(flags byte) { - p[6] &= 0x1F - p[6] |= flags << 5 -} - -func (p IPv4Packet) SourceIP() netip.Addr { - return netip.AddrFrom4([4]byte{p[12], p[13], p[14], p[15]}) -} - -func (p IPv4Packet) SetSourceIP(ip netip.Addr) { - if ip.Is4() { - copy(p[12:16], ip.AsSlice()) - } -} - -func (p IPv4Packet) DestinationIP() netip.Addr { - return netip.AddrFrom4([4]byte{p[16], p[17], p[18], p[19]}) -} - -func (p IPv4Packet) SetDestinationIP(ip netip.Addr) { - if ip.Is4() { - copy(p[16:20], ip.AsSlice()) - } -} - -func (p IPv4Packet) Checksum() uint16 { - return binary.BigEndian.Uint16(p[10:]) -} - -func (p IPv4Packet) SetChecksum(sum [2]byte) { - p[10] = sum[0] - p[11] = sum[1] -} - -func (p IPv4Packet) TimeToLive() uint8 { - return p[8] -} - -func (p IPv4Packet) SetTimeToLive(ttl uint8) { - p[8] = ttl -} - -func (p IPv4Packet) DecTimeToLive() { - p[8] = p[8] - uint8(1) -} - -func (p IPv4Packet) ResetChecksum() { - p.SetChecksum(zeroChecksum) - p.SetChecksum(Checksum(0, p[:p.HeaderLen()])) -} - -// PseudoSum for tcp checksum -func (p IPv4Packet) PseudoSum() uint32 { - sum := Sum(p[12:20]) - sum += uint32(p.Protocol()) - sum += uint32(p.DataLen()) - return sum -} - -func (p IPv4Packet) Valid() bool { - return len(p) >= IPv4HeaderSize && p.TotalLen() >= p.HeaderLen() && uint16(len(p)) >= p.TotalLen() -} - -func (p IPv4Packet) Verify() error { - if len(p) < IPv4PacketMinLength { - return ErrInvalidLength - } - - checksum := []byte{p[10], p[11]} - headerLength := uint16(p[0]&0xF) * 4 - packetLength := binary.BigEndian.Uint16(p[2:]) - - if p[0]>>4 != 4 { - return ErrInvalidIPVersion - } - - if uint16(len(p)) < packetLength || packetLength < headerLength { - return ErrInvalidLength - } - - p[10] = 0 - p[11] = 0 - defer copy(p[10:12], checksum) - - answer := Checksum(0, p[:headerLength]) - - if answer[0] != checksum[0] || answer[1] != checksum[1] { - return ErrInvalidChecksum - } - - return nil -} - -var _ IP = (*IPv4Packet)(nil) diff --git a/internal/clashtcpip/ipv6.go b/internal/clashtcpip/ipv6.go deleted file mode 100644 index 20147e5..0000000 --- a/internal/clashtcpip/ipv6.go +++ /dev/null @@ -1,141 +0,0 @@ -package clashtcpip - -import ( - "encoding/binary" - "net/netip" -) - -const ( - versTCFL = 0 - - IPv6PayloadLenOffset = 4 - - IPv6NextHeaderOffset = 6 - hopLimit = 7 - v6SrcAddr = 8 - v6DstAddr = v6SrcAddr + IPv6AddressSize - - IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize -) - -const ( - versIHL = 0 - tos = 1 - ipVersionShift = 4 - ipIHLMask = 0x0f - IPv4IHLStride = 4 -) - -type IPv6Packet []byte - -const ( - IPv6MinimumSize = IPv6FixedHeaderSize - - IPv6AddressSize = 16 - - IPv6Version = 6 - - IPv6MinimumMTU = 1280 -) - -func (b IPv6Packet) PayloadLength() uint16 { - return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) -} - -func (b IPv6Packet) HopLimit() uint8 { - return b[hopLimit] -} - -func (b IPv6Packet) NextHeader() byte { - return b[IPv6NextHeaderOffset] -} - -func (b IPv6Packet) Protocol() IPProtocol { - return b.NextHeader() -} - -func (b IPv6Packet) Payload() []byte { - return b[IPv6MinimumSize:][:b.PayloadLength()] -} - -func (b IPv6Packet) SourceIP() netip.Addr { - addr, _ := netip.AddrFromSlice(b[v6SrcAddr:][:IPv6AddressSize]) - return addr -} - -func (b IPv6Packet) DestinationIP() netip.Addr { - addr, _ := netip.AddrFromSlice(b[v6DstAddr:][:IPv6AddressSize]) - return addr -} - -func (IPv6Packet) Checksum() uint16 { - return 0 -} - -func (b IPv6Packet) TOS() (uint8, uint32) { - v := binary.BigEndian.Uint32(b[versTCFL:]) - return uint8(v >> 20), v & 0xfffff -} - -func (b IPv6Packet) SetTOS(t uint8, l uint32) { - vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff) - binary.BigEndian.PutUint32(b[versTCFL:], vtf) -} - -func (b IPv6Packet) SetPayloadLength(payloadLength uint16) { - binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength) -} - -func (b IPv6Packet) SetSourceIP(addr netip.Addr) { - if addr.Is6() { - copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice()) - } -} - -func (b IPv6Packet) SetDestinationIP(addr netip.Addr) { - if addr.Is6() { - copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice()) - } -} - -func (b IPv6Packet) SetHopLimit(v uint8) { - b[hopLimit] = v -} - -func (b IPv6Packet) SetNextHeader(v byte) { - b[IPv6NextHeaderOffset] = v -} - -func (b IPv6Packet) SetProtocol(p IPProtocol) { - b.SetNextHeader(p) -} - -func (b IPv6Packet) DecTimeToLive() { - b[hopLimit] = b[hopLimit] - uint8(1) -} - -func (IPv6Packet) SetChecksum(uint16) { -} - -func (IPv6Packet) ResetChecksum() { -} - -func (b IPv6Packet) PseudoSum() uint32 { - sum := Sum(b[v6SrcAddr:IPv6FixedHeaderSize]) - sum += uint32(b.Protocol()) - sum += uint32(b.PayloadLength()) - return sum -} - -func (b IPv6Packet) Valid() bool { - return len(b) >= IPv6MinimumSize && len(b) >= int(b.PayloadLength())+IPv6MinimumSize -} - -func IPVersion(b []byte) int { - if len(b) < versIHL+1 { - return -1 - } - return int(b[versIHL] >> ipVersionShift) -} - -var _ IP = (*IPv6Packet)(nil) diff --git a/internal/clashtcpip/tcp.go b/internal/clashtcpip/tcp.go deleted file mode 100644 index ee7a894..0000000 --- a/internal/clashtcpip/tcp.go +++ /dev/null @@ -1,94 +0,0 @@ -package clashtcpip - -import ( - "encoding/binary" - "net" -) - -const ( - TCPFin uint16 = 1 << 0 - TCPSyn uint16 = 1 << 1 - TCPRst uint16 = 1 << 2 - TCPPuh uint16 = 1 << 3 - TCPAck uint16 = 1 << 4 - TCPUrg uint16 = 1 << 5 - TCPEce uint16 = 1 << 6 - TCPEwr uint16 = 1 << 7 - TCPNs uint16 = 1 << 8 -) - -const TCPHeaderSize = 20 - -type TCPPacket []byte - -func (p TCPPacket) SourcePort() uint16 { - return binary.BigEndian.Uint16(p) -} - -func (p TCPPacket) SetSourcePort(port uint16) { - binary.BigEndian.PutUint16(p, port) -} - -func (p TCPPacket) DestinationPort() uint16 { - return binary.BigEndian.Uint16(p[2:]) -} - -func (p TCPPacket) SetDestinationPort(port uint16) { - binary.BigEndian.PutUint16(p[2:], port) -} - -func (p TCPPacket) Flags() uint16 { - return uint16(p[13] | (p[12] & 0x1)) -} - -func (p TCPPacket) Checksum() uint16 { - return binary.BigEndian.Uint16(p[16:]) -} - -func (p TCPPacket) SetChecksum(sum [2]byte) { - p[16] = sum[0] - p[17] = sum[1] -} - -func (p TCPPacket) OffloadChecksum() { - p.SetChecksum(zeroChecksum) -} - -func (p TCPPacket) ResetChecksum(psum uint32) { - p.SetChecksum(zeroChecksum) - p.SetChecksum(Checksum(psum, p)) -} - -func (p TCPPacket) Valid() bool { - return len(p) >= TCPHeaderSize -} - -func (p TCPPacket) Verify(sourceAddress net.IP, targetAddress net.IP) error { - var checksum [2]byte - checksum[0] = p[16] - checksum[1] = p[17] - - // reset checksum - p[16] = 0 - p[17] = 0 - - // restore checksum - defer func() { - p[16] = checksum[0] - p[17] = checksum[1] - }() - - // check checksum - s := uint32(0) - s += Sum(sourceAddress) - s += Sum(targetAddress) - s += uint32(TCP) - s += uint32(len(p)) - - check := Checksum(s, p) - if checksum[0] != check[0] || checksum[1] != check[1] { - return ErrInvalidChecksum - } - - return nil -} diff --git a/internal/clashtcpip/tcpip.go b/internal/clashtcpip/tcpip.go deleted file mode 100644 index 2994637..0000000 --- a/internal/clashtcpip/tcpip.go +++ /dev/null @@ -1,24 +0,0 @@ -package clashtcpip - -var zeroChecksum = [2]byte{0x00, 0x00} - -var SumFnc = SumCompat - -func Sum(b []byte) uint32 { - return SumFnc(b) -} - -// Checksum for Internet Protocol family headers -func Checksum(sum uint32, b []byte) (answer [2]byte) { - sum += Sum(b) - sum = (sum >> 16) + (sum & 0xffff) - sum += sum >> 16 - sum = ^sum - answer[0] = byte(sum >> 8) - answer[1] = byte(sum) - return -} - -func SetIPv4(packet []byte) { - packet[0] = (packet[0] & 0x0f) | (4 << 4) -} diff --git a/internal/clashtcpip/tcpip_amd64.go b/internal/clashtcpip/tcpip_amd64.go deleted file mode 100644 index 711a85c..0000000 --- a/internal/clashtcpip/tcpip_amd64.go +++ /dev/null @@ -1,26 +0,0 @@ -//go:build !noasm - -package clashtcpip - -import ( - "unsafe" - - "golang.org/x/sys/cpu" -) - -//go:noescape -func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr - -func SumAVX2(data []byte) uint32 { - if len(data) == 0 { - return 0 - } - - return uint32(sumAsmAvx2(unsafe.Pointer(&data[0]), uintptr(len(data)))) -} - -func init() { - if cpu.X86.HasAVX2 { - SumFnc = SumAVX2 - } -} diff --git a/internal/clashtcpip/tcpip_amd64.s b/internal/clashtcpip/tcpip_amd64.s deleted file mode 100644 index 100820b..0000000 --- a/internal/clashtcpip/tcpip_amd64.s +++ /dev/null @@ -1,140 +0,0 @@ -#include "textflag.h" - -DATA endian_swap_mask<>+0(SB)/8, $0x607040502030001 -DATA endian_swap_mask<>+8(SB)/8, $0xE0F0C0D0A0B0809 -DATA endian_swap_mask<>+16(SB)/8, $0x607040502030001 -DATA endian_swap_mask<>+24(SB)/8, $0xE0F0C0D0A0B0809 -GLOBL endian_swap_mask<>(SB), RODATA, $32 - -// func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr -// -// args (8 bytes aligned): -// data unsafe.Pointer - 8 bytes - 0 offset -// length uintptr - 8 bytes - 8 offset -// result uintptr - 8 bytes - 16 offset -#define PDATA AX -#define LENGTH CX -#define RESULT BX -TEXT ·sumAsmAvx2(SB),NOSPLIT,$0-24 - MOVQ data+0(FP), PDATA - MOVQ length+8(FP), LENGTH - XORQ RESULT, RESULT - -#define VSUM Y0 -#define ENDIAN_SWAP_MASK Y1 -BEGIN: - VMOVDQU endian_swap_mask<>(SB), ENDIAN_SWAP_MASK - VPXOR VSUM, VSUM, VSUM - -#define LOADED_0 Y2 -#define LOADED_1 Y3 -#define LOADED_2 Y4 -#define LOADED_3 Y5 -BATCH_64: - CMPQ LENGTH, $64 - JB BATCH_32 - VPMOVZXWD (PDATA), LOADED_0 - VPMOVZXWD 16(PDATA), LOADED_1 - VPMOVZXWD 32(PDATA), LOADED_2 - VPMOVZXWD 48(PDATA), LOADED_3 - VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0 - VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1 - VPSHUFB ENDIAN_SWAP_MASK, LOADED_2, LOADED_2 - VPSHUFB ENDIAN_SWAP_MASK, LOADED_3, LOADED_3 - VPADDD LOADED_0, VSUM, VSUM - VPADDD LOADED_1, VSUM, VSUM - VPADDD LOADED_2, VSUM, VSUM - VPADDD LOADED_3, VSUM, VSUM - ADDQ $-64, LENGTH - ADDQ $64, PDATA - JMP BATCH_64 -#undef LOADED_0 -#undef LOADED_1 -#undef LOADED_2 -#undef LOADED_3 - -#define LOADED_0 Y2 -#define LOADED_1 Y3 -BATCH_32: - CMPQ LENGTH, $32 - JB BATCH_16 - VPMOVZXWD (PDATA), LOADED_0 - VPMOVZXWD 16(PDATA), LOADED_1 - VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0 - VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1 - VPADDD LOADED_0, VSUM, VSUM - VPADDD LOADED_1, VSUM, VSUM - ADDQ $-32, LENGTH - ADDQ $32, PDATA - JMP BATCH_32 -#undef LOADED_0 -#undef LOADED_1 - -#define LOADED Y2 -BATCH_16: - CMPQ LENGTH, $16 - JB COLLECT - VPMOVZXWD (PDATA), LOADED - VPSHUFB ENDIAN_SWAP_MASK, LOADED, LOADED - VPADDD LOADED, VSUM, VSUM - ADDQ $-16, LENGTH - ADDQ $16, PDATA - JMP BATCH_16 -#undef LOADED - -#define EXTRACTED Y2 -#define EXTRACTED_128 X2 -#define TEMP_64 DX -COLLECT: - VEXTRACTI128 $0, VSUM, EXTRACTED_128 - VPEXTRD $0, EXTRACTED_128, TEMP_64 - ADDL TEMP_64, RESULT - VPEXTRD $1, EXTRACTED_128, TEMP_64 - ADDL TEMP_64, RESULT - VPEXTRD $2, EXTRACTED_128, TEMP_64 - ADDL TEMP_64, RESULT - VPEXTRD $3, EXTRACTED_128, TEMP_64 - ADDL TEMP_64, RESULT - VEXTRACTI128 $1, VSUM, EXTRACTED_128 - VPEXTRD $0, EXTRACTED_128, TEMP_64 - ADDL TEMP_64, RESULT - VPEXTRD $1, EXTRACTED_128, TEMP_64 - ADDL TEMP_64, RESULT - VPEXTRD $2, EXTRACTED_128, TEMP_64 - ADDL TEMP_64, RESULT - VPEXTRD $3, EXTRACTED_128, TEMP_64 - ADDL TEMP_64, RESULT -#undef EXTRACTED -#undef EXTRACTED_128 -#undef TEMP_64 - -#define TEMP DX -#define TEMP2 SI -BATCH_2: - CMPQ LENGTH, $2 - JB BATCH_1 - XORQ TEMP, TEMP - MOVW (PDATA), TEMP - MOVQ TEMP, TEMP2 - SHRW $8, TEMP2 - SHLW $8, TEMP - ORW TEMP2, TEMP - ADDL TEMP, RESULT - ADDQ $-2, LENGTH - ADDQ $2, PDATA - JMP BATCH_2 -#undef TEMP - -#define TEMP DX -BATCH_1: - CMPQ LENGTH, $0 - JZ RETURN - XORQ TEMP, TEMP - MOVB (PDATA), TEMP - SHLW $8, TEMP - ADDL TEMP, RESULT -#undef TEMP - -RETURN: - MOVQ RESULT, result+16(FP) - RET diff --git a/internal/clashtcpip/tcpip_amd64_test.go b/internal/clashtcpip/tcpip_amd64_test.go deleted file mode 100644 index e02a1b9..0000000 --- a/internal/clashtcpip/tcpip_amd64_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package clashtcpip - -import ( - "crypto/rand" - "testing" - - "golang.org/x/sys/cpu" -) - -func Test_SumAVX2(t *testing.T) { - if !cpu.X86.HasAVX2 { - t.Skipf("AVX2 unavailable") - } - - bytes := make([]byte, chunkSize) - - for size := 0; size <= chunkSize; size++ { - for count := 0; count < chunkCount; count++ { - _, err := rand.Reader.Read(bytes[:size]) - if err != nil { - t.Skipf("Rand read failed: %v", err) - } - - compat := SumCompat(bytes[:size]) - avx := SumAVX2(bytes[:size]) - - if compat != avx { - t.Errorf("Sum of length=%d mismatched", size) - } - } - } -} - -func Benchmark_SumAVX2(b *testing.B) { - if !cpu.X86.HasAVX2 { - b.Skipf("AVX2 unavailable") - } - - bytes := make([]byte, chunkSize) - - _, err := rand.Reader.Read(bytes) - if err != nil { - b.Skipf("Rand read failed: %v", err) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - SumAVX2(bytes) - } -} diff --git a/internal/clashtcpip/tcpip_arm64.go b/internal/clashtcpip/tcpip_arm64.go deleted file mode 100644 index 543803f..0000000 --- a/internal/clashtcpip/tcpip_arm64.go +++ /dev/null @@ -1,24 +0,0 @@ -package clashtcpip - -import ( - "unsafe" - - "golang.org/x/sys/cpu" -) - -//go:noescape -func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr - -func SumNeon(data []byte) uint32 { - if len(data) == 0 { - return 0 - } - - return uint32(sumAsmNeon(unsafe.Pointer(&data[0]), uintptr(len(data)))) -} - -func init() { - if cpu.ARM64.HasASIMD { - SumFnc = SumNeon - } -} diff --git a/internal/clashtcpip/tcpip_arm64.s b/internal/clashtcpip/tcpip_arm64.s deleted file mode 100644 index f6d57cf..0000000 --- a/internal/clashtcpip/tcpip_arm64.s +++ /dev/null @@ -1,118 +0,0 @@ -#include "textflag.h" - -// func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr -// -// args (8 bytes aligned): -// data unsafe.Pointer - 8 bytes - 0 offset -// length uintptr - 8 bytes - 8 offset -// result uintptr - 8 bytes - 16 offset -#define PDATA R0 -#define LENGTH R1 -#define RESULT R2 -#define VSUM V0 -TEXT ·sumAsmNeon(SB),NOSPLIT,$0-24 - MOVD data+0(FP), PDATA - MOVD length+8(FP), LENGTH - MOVD $0, RESULT - VMOVQ $0, $0, VSUM - -#define LOADED_0 V1 -#define LOADED_1 V2 -#define LOADED_2 V3 -#define LOADED_3 V4 -BATCH_32: - CMP $32, LENGTH - BLO BATCH_16 - VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8, LOADED_2.B8, LOADED_3.B8] - VREV16 LOADED_0.B8, LOADED_0.B8 - VREV16 LOADED_1.B8, LOADED_1.B8 - VREV16 LOADED_2.B8, LOADED_2.B8 - VREV16 LOADED_3.B8, LOADED_3.B8 - VUSHLL $0, LOADED_0.H4, LOADED_0.S4 - VUSHLL $0, LOADED_1.H4, LOADED_1.S4 - VUSHLL $0, LOADED_2.H4, LOADED_2.S4 - VUSHLL $0, LOADED_3.H4, LOADED_3.S4 - VADD LOADED_0.S4, VSUM.S4, VSUM.S4 - VADD LOADED_1.S4, VSUM.S4, VSUM.S4 - VADD LOADED_2.S4, VSUM.S4, VSUM.S4 - VADD LOADED_3.S4, VSUM.S4, VSUM.S4 - ADD $-32, LENGTH - ADD $32, PDATA - B BATCH_32 -#undef LOADED_0 -#undef LOADED_1 -#undef LOADED_2 -#undef LOADED_3 - -#define LOADED_0 V1 -#define LOADED_1 V2 -BATCH_16: - CMP $16, LENGTH - BLO BATCH_8 - VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8] - VREV16 LOADED_0.B8, LOADED_0.B8 - VREV16 LOADED_1.B8, LOADED_1.B8 - VUSHLL $0, LOADED_0.H4, LOADED_0.S4 - VUSHLL $0, LOADED_1.H4, LOADED_1.S4 - VADD LOADED_0.S4, VSUM.S4, VSUM.S4 - VADD LOADED_1.S4, VSUM.S4, VSUM.S4 - ADD $-16, LENGTH - ADD $16, PDATA - B BATCH_16 -#undef LOADED_0 -#undef LOADED_1 - -#define LOADED_0 V1 -BATCH_8: - CMP $8, LENGTH - BLO BATCH_2 - VLD1 (PDATA), [LOADED_0.B8] - VREV16 LOADED_0.B8, LOADED_0.B8 - VUSHLL $0, LOADED_0.H4, LOADED_0.S4 - VADD LOADED_0.S4, VSUM.S4, VSUM.S4 - ADD $-8, LENGTH - ADD $8, PDATA - B BATCH_8 -#undef LOADED_0 - -#define LOADED_L R3 -#define LOADED_H R4 -BATCH_2: - CMP $2, LENGTH - BLO BATCH_1 - MOVBU (PDATA), LOADED_H - MOVBU 1(PDATA), LOADED_L - LSL $8, LOADED_H - ORR LOADED_H, LOADED_L, LOADED_L - ADD LOADED_L, RESULT, RESULT - ADD $2, PDATA - ADD $-2, LENGTH - B BATCH_2 -#undef LOADED_H -#undef LOADED_L - -#define LOADED R3 -BATCH_1: - CMP $1, LENGTH - BLO COLLECT - MOVBU (PDATA), LOADED - LSL $8, LOADED - ADD LOADED, RESULT, RESULT - -#define EXTRACTED R3 -COLLECT: - VMOV VSUM.S[0], EXTRACTED - ADD EXTRACTED, RESULT - VMOV VSUM.S[1], EXTRACTED - ADD EXTRACTED, RESULT - VMOV VSUM.S[2], EXTRACTED - ADD EXTRACTED, RESULT - VMOV VSUM.S[3], EXTRACTED - ADD EXTRACTED, RESULT -#undef VSUM -#undef PDATA -#undef LENGTH - -RETURN: - MOVD RESULT, result+16(FP) - RET diff --git a/internal/clashtcpip/tcpip_arm64_test.go b/internal/clashtcpip/tcpip_arm64_test.go deleted file mode 100644 index bfe07a6..0000000 --- a/internal/clashtcpip/tcpip_arm64_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package clashtcpip - -import ( - "crypto/rand" - "testing" - - "golang.org/x/sys/cpu" -) - -func Test_SumNeon(t *testing.T) { - if !cpu.ARM64.HasASIMD { - t.Skipf("Neon unavailable") - } - - bytes := make([]byte, chunkSize) - - for size := 0; size <= chunkSize; size++ { - for count := 0; count < chunkCount; count++ { - _, err := rand.Reader.Read(bytes[:size]) - if err != nil { - t.Skipf("Rand read failed: %v", err) - } - - compat := SumCompat(bytes[:size]) - neon := SumNeon(bytes[:size]) - - if compat != neon { - t.Errorf("Sum of length=%d mismatched", size) - } - } - } -} - -func Benchmark_SumNeon(b *testing.B) { - if !cpu.ARM64.HasASIMD { - b.Skipf("Neon unavailable") - } - - bytes := make([]byte, chunkSize) - - _, err := rand.Reader.Read(bytes) - if err != nil { - b.Skipf("Rand read failed: %v", err) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - SumNeon(bytes) - } -} diff --git a/internal/clashtcpip/tcpip_compat.go b/internal/clashtcpip/tcpip_compat.go deleted file mode 100644 index a72a489..0000000 --- a/internal/clashtcpip/tcpip_compat.go +++ /dev/null @@ -1,14 +0,0 @@ -package clashtcpip - -func SumCompat(b []byte) (sum uint32) { - n := len(b) - if n&1 != 0 { - n-- - sum += uint32(b[n]) << 8 - } - - for i := 0; i < n; i += 2 { - sum += (uint32(b[i]) << 8) | uint32(b[i+1]) - } - return -} diff --git a/internal/clashtcpip/tcpip_compat_test.go b/internal/clashtcpip/tcpip_compat_test.go deleted file mode 100644 index 828d886..0000000 --- a/internal/clashtcpip/tcpip_compat_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package clashtcpip - -import ( - "crypto/rand" - "testing" -) - -const ( - chunkSize = 9000 - chunkCount = 10 -) - -func Benchmark_SumCompat(b *testing.B) { - bytes := make([]byte, chunkSize) - - _, err := rand.Reader.Read(bytes) - if err != nil { - b.Skipf("Rand read failed: %v", err) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - SumCompat(bytes) - } -} diff --git a/internal/clashtcpip/udp.go b/internal/clashtcpip/udp.go deleted file mode 100644 index f576e99..0000000 --- a/internal/clashtcpip/udp.go +++ /dev/null @@ -1,59 +0,0 @@ -package clashtcpip - -import ( - "encoding/binary" -) - -const UDPHeaderSize = 8 - -type UDPPacket []byte - -func (p UDPPacket) Length() uint16 { - return binary.BigEndian.Uint16(p[4:]) -} - -func (p UDPPacket) SetLength(length uint16) { - binary.BigEndian.PutUint16(p[4:], length) -} - -func (p UDPPacket) SourcePort() uint16 { - return binary.BigEndian.Uint16(p) -} - -func (p UDPPacket) SetSourcePort(port uint16) { - binary.BigEndian.PutUint16(p, port) -} - -func (p UDPPacket) DestinationPort() uint16 { - return binary.BigEndian.Uint16(p[2:]) -} - -func (p UDPPacket) SetDestinationPort(port uint16) { - binary.BigEndian.PutUint16(p[2:], port) -} - -func (p UDPPacket) Payload() []byte { - return p[UDPHeaderSize:p.Length()] -} - -func (p UDPPacket) Checksum() uint16 { - return binary.BigEndian.Uint16(p[6:]) -} - -func (p UDPPacket) SetChecksum(sum [2]byte) { - p[6] = sum[0] - p[7] = sum[1] -} - -func (p UDPPacket) OffloadChecksum() { - p.SetChecksum(zeroChecksum) -} - -func (p UDPPacket) ResetChecksum(psum uint32) { - p.SetChecksum(zeroChecksum) - p.SetChecksum(Checksum(psum, p)) -} - -func (p UDPPacket) Valid() bool { - return len(p) >= UDPHeaderSize && uint16(len(p)) >= p.Length() -} diff --git a/internal/gtcpip/README.md b/internal/gtcpip/README.md new file mode 100644 index 0000000..10d4839 --- /dev/null +++ b/internal/gtcpip/README.md @@ -0,0 +1,4 @@ +# gtcpip + +Minimal tcpip package kanged from gvisor +Version 20241007.0 diff --git a/internal/gtcpip/checksum/checksum.go b/internal/gtcpip/checksum/checksum.go new file mode 100644 index 0000000..5d4e117 --- /dev/null +++ b/internal/gtcpip/checksum/checksum.go @@ -0,0 +1,68 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checksum provides the implementation of the encoding and decoding of +// network protocol headers. +package checksum + +import ( + "encoding/binary" +) + +// Size is the size of a checksum. +// +// The checksum is held in a uint16 which is 2 bytes. +const Size = 2 + +// Put puts the checksum in the provided byte slice. +func Put(b []byte, xsum uint16) { + binary.BigEndian.PutUint16(b, xsum) +} + +// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the +// given byte array. This function uses an optimized version of the checksum +// algorithm. +// +// The initial checksum must have been computed on an even number of bytes. +func Checksum(buf []byte, initial uint16) uint16 { + s, _ := calculateChecksum(buf, false, initial) + return s +} + +// Checksumer calculates checksum defined in RFC 1071. +type Checksumer struct { + sum uint16 + odd bool +} + +// Add adds b to checksum. +func (c *Checksumer) Add(b []byte) { + if len(b) > 0 { + c.sum, c.odd = calculateChecksum(b, c.odd, c.sum) + } +} + +// Checksum returns the latest checksum value. +func (c *Checksumer) Checksum() uint16 { + return c.sum +} + +// Combine combines the two uint16 to form their checksum. This is done +// by adding them and the carry. +// +// Note that checksum a must have been computed on an even number of bytes. +func Combine(a, b uint16) uint16 { + v := uint32(a) + uint32(b) + return uint16(v + v>>16) +} diff --git a/internal/gtcpip/checksum/checksum_unsafe.go b/internal/gtcpip/checksum/checksum_unsafe.go new file mode 100644 index 0000000..66b7ab6 --- /dev/null +++ b/internal/gtcpip/checksum/checksum_unsafe.go @@ -0,0 +1,182 @@ +// Copyright 2023 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checksum + +import ( + "encoding/binary" + "math/bits" + "unsafe" +) + +// Note: odd indicates whether initial is a partial checksum over an odd number +// of bytes. +func calculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) { + // Use a larger-than-uint16 accumulator to benefit from parallel summation + // as described in RFC 1071 1.2.C. + acc := uint64(initial) + + // Handle an odd number of previously-summed bytes, and get the return + // value for odd. + if odd { + acc += uint64(buf[0]) + buf = buf[1:] + } + odd = len(buf)&1 != 0 + + // Aligning &buf[0] below is much simpler if len(buf) >= 8; special-case + // smaller bufs. + if len(buf) < 8 { + if len(buf) >= 4 { + acc += (uint64(buf[0]) << 8) + uint64(buf[1]) + acc += (uint64(buf[2]) << 8) + uint64(buf[3]) + buf = buf[4:] + } + if len(buf) >= 2 { + acc += (uint64(buf[0]) << 8) + uint64(buf[1]) + buf = buf[2:] + } + if len(buf) >= 1 { + acc += uint64(buf[0]) << 8 + // buf = buf[1:] is skipped because it's unused and nogo will + // complain. + } + return reduce(acc), odd + } + + // On little-endian architectures, multi-byte loads from buf will load + // bytes in the wrong order. Rather than byte-swap after each load (slow), + // we byte-swap the accumulator before summing any bytes and byte-swap it + // back before returning, which still produces the correct result as + // described in RFC 1071 1.2.B "Byte Order Independence". + // + // acc is at most a uint16 + a uint8, so its upper 32 bits must be 0s. We + // preserve this property by byte-swapping only the lower 32 bits of acc, + // so that additions to acc performed during alignment can't overflow. + acc = uint64(bswapIfLittleEndian32(uint32(acc))) + + // Align &buf[0] to an 8-byte boundary. + bswapped := false + if sliceAddr(buf)&1 != 0 { + // Compute the rest of the partial checksum with bytes swapped, and + // swap back before returning; see the last paragraph of + // RFC 1071 1.2.B. + acc = uint64(bits.ReverseBytes32(uint32(acc))) + bswapped = true + // No `<< 8` here due to the byte swap we just did. + acc += uint64(bswapIfLittleEndian16(uint16(buf[0]))) + buf = buf[1:] + } + if sliceAddr(buf)&2 != 0 { + acc += uint64(*(*uint16)(unsafe.Pointer(&buf[0]))) + buf = buf[2:] + } + if sliceAddr(buf)&4 != 0 { + acc += uint64(*(*uint32)(unsafe.Pointer(&buf[0]))) + buf = buf[4:] + } + + // Sum 64 bytes at a time. Beyond this point, additions to acc may + // overflow, so we have to handle carrying. + for len(buf) >= 64 { + var carry uint64 + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[32])), carry) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[40])), carry) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[48])), carry) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[56])), carry) + acc, _ = bits.Add64(acc, 0, carry) + buf = buf[64:] + } + + // Sum the remaining 0-63 bytes. + if len(buf) >= 32 { + var carry uint64 + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry) + acc, _ = bits.Add64(acc, 0, carry) + buf = buf[32:] + } + if len(buf) >= 16 { + var carry uint64 + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0) + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry) + acc, _ = bits.Add64(acc, 0, carry) + buf = buf[16:] + } + if len(buf) >= 8 { + var carry uint64 + acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0) + acc, _ = bits.Add64(acc, 0, carry) + buf = buf[8:] + } + if len(buf) >= 4 { + var carry uint64 + acc, carry = bits.Add64(acc, uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), 0) + acc, _ = bits.Add64(acc, 0, carry) + buf = buf[4:] + } + if len(buf) >= 2 { + var carry uint64 + acc, carry = bits.Add64(acc, uint64(*(*uint16)(unsafe.Pointer(&buf[0]))), 0) + acc, _ = bits.Add64(acc, 0, carry) + buf = buf[2:] + } + if len(buf) >= 1 { + // bswapIfBigEndian16(buf[0]) == bswapIfLittleEndian16(buf[0]<<8). + var carry uint64 + acc, carry = bits.Add64(acc, uint64(bswapIfBigEndian16(uint16(buf[0]))), 0) + acc, _ = bits.Add64(acc, 0, carry) + // buf = buf[1:] is skipped because it's unused and nogo will complain. + } + + // Reduce the checksum to 16 bits and undo byte swaps before returning. + acc16 := bswapIfLittleEndian16(reduce(acc)) + if bswapped { + acc16 = bits.ReverseBytes16(acc16) + } + return acc16, odd +} + +func reduce(acc uint64) uint16 { + // Ideally we would do: + // return uint16(acc>>48) +' uint16(acc>>32) +' uint16(acc>>16) +' uint16(acc) + // for more instruction-level parallelism; however, there is no + // bits.Add16(). + acc = (acc >> 32) + (acc & 0xffff_ffff) // at most 0x1_ffff_fffe + acc32 := uint32(acc>>32 + acc) // at most 0xffff_ffff + acc32 = (acc32 >> 16) + (acc32 & 0xffff) // at most 0x1_fffe + return uint16(acc32>>16 + acc32) // at most 0xffff +} + +func bswapIfLittleEndian32(val uint32) uint32 { + return binary.BigEndian.Uint32((*[4]byte)(unsafe.Pointer(&val))[:]) +} + +func bswapIfLittleEndian16(val uint16) uint16 { + return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:]) +} + +func bswapIfBigEndian16(val uint16) uint16 { + return binary.LittleEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:]) +} + +func sliceAddr(buf []byte) uintptr { + return uintptr(unsafe.Pointer(unsafe.SliceData(buf))) +} diff --git a/internal/gtcpip/errors.go b/internal/gtcpip/errors.go new file mode 100644 index 0000000..43b2eae --- /dev/null +++ b/internal/gtcpip/errors.go @@ -0,0 +1,46 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "fmt" +) + +// Error represents an error in the netstack error space. +// +// The error interface is intentionally omitted to avoid loss of type +// information that would occur if these errors were passed as error. +type Error interface { + isError() + + // IgnoreStats indicates whether this error should be included in failure + // counts in tcpip.Stats structs. + IgnoreStats() bool + + fmt.Stringer +} + +// ErrBadAddress indicates a bad address was provided. +// +// +stateify savable +type ErrBadAddress struct{} + +func (*ErrBadAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadAddress) IgnoreStats() bool { + return false +} +func (*ErrBadAddress) String() string { return "bad address" } diff --git a/internal/gtcpip/header/checksum.go b/internal/gtcpip/header/checksum.go new file mode 100644 index 0000000..2c21e6d --- /dev/null +++ b/internal/gtcpip/header/checksum.go @@ -0,0 +1,107 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package header provides the implementation of the encoding and decoding of +// network protocol headers. +package header + +import ( + "encoding/binary" + "fmt" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" +) + +// PseudoHeaderChecksum calculates the pseudo-header checksum for the given +// destination protocol and network address. Pseudo-headers are needed by +// transport layers when calculating their own checksum. +func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr []byte, dstAddr []byte, totalLen uint16) uint16 { + xsum := checksum.Checksum(srcAddr, 0) + xsum = checksum.Checksum(dstAddr, xsum) + + // Add the length portion of the checksum to the pseudo-checksum. + var tmp [2]byte + binary.BigEndian.PutUint16(tmp[:], totalLen) + xsum = checksum.Checksum(tmp[:], xsum) + + return checksum.Checksum([]byte{0, uint8(protocol)}, xsum) +} + +// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated +// checksum. +// +// The value MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + if old == new { + return xsum + } + return checksum.Combine(xsum, checksum.Combine(new, ^old)) +} + +// checksumUpdate2ByteAlignedAddress updates an address in a calculated +// checksum. +// +// The addresses must have the same length and must contain an even number +// of bytes. The address MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 { + const uint16Bytes = 2 + + if old.BitLen() != new.BitLen() { + panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", old.BitLen()/8, new.BitLen()/8)) + } + + if oldBytes := old.BitLen() % 16; oldBytes != 0 { + panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", oldBytes)) + } + + oldAddr := old.AsSlice() + newAddr := new.AsSlice() + + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + for len(oldAddr) != 0 { + // Convert the 2 byte sequences to uint16 values then apply the increment + // update. + xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(oldAddr[0])<<8)+uint16(oldAddr[1]), (uint16(newAddr[0])<<8)+uint16(newAddr[1])) + oldAddr = oldAddr[uint16Bytes:] + newAddr = newAddr[uint16Bytes:] + } + + return xsum +} diff --git a/internal/gtcpip/header/eth.go b/internal/gtcpip/header/eth.go new file mode 100644 index 0000000..9d876ee --- /dev/null +++ b/internal/gtcpip/header/eth.go @@ -0,0 +1,192 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "github.com/sagernet/sing-tun/internal/gtcpip" +) + +const ( + dstMAC = 0 + srcMAC = 6 + ethType = 12 +) + +// EthernetFields contains the fields of an ethernet frame header. It is used to +// describe the fields of a frame that needs to be encoded. +type EthernetFields struct { + // SrcAddr is the "MAC source" field of an ethernet frame header. + SrcAddr tcpip.LinkAddress + + // DstAddr is the "MAC destination" field of an ethernet frame header. + DstAddr tcpip.LinkAddress + + // Type is the "ethertype" field of an ethernet frame header. + Type tcpip.NetworkProtocolNumber +} + +// Ethernet represents an ethernet frame header stored in a byte array. +type Ethernet []byte + +const ( + // EthernetMinimumSize is the minimum size of a valid ethernet frame. + EthernetMinimumSize = 14 + + // EthernetMaximumSize is the maximum size of a valid ethernet frame. + EthernetMaximumSize = 18 + + // EthernetAddressSize is the size, in bytes, of an ethernet address. + EthernetAddressSize = 6 + + // UnspecifiedEthernetAddress is the unspecified ethernet address + // (all bits set to 0). + UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + + // EthernetBroadcastAddress is an ethernet address that addresses every node + // on a local link. + EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff") + + // unicastMulticastFlagMask is the mask of the least significant bit in + // the first octet (in network byte order) of an ethernet address that + // determines whether the ethernet address is a unicast or multicast. If + // the masked bit is a 1, then the address is a multicast, unicast + // otherwise. + // + // See the IEEE Std 802-2001 document for more details. Specifically, + // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf: + // "A 48-bit universal address consists of two parts. The first 24 bits + // correspond to the OUI as assigned by the IEEE, expect that the + // assignee may set the LSB of the first octet to 1 for group addresses + // or set it to 0 for individual addresses." + unicastMulticastFlagMask = 1 + + // unicastMulticastFlagByteIdx is the byte that holds the + // unicast/multicast flag. See unicastMulticastFlagMask. + unicastMulticastFlagByteIdx = 0 +) + +const ( + // EthernetProtocolAll is a catch-all for all protocols carried inside + // an ethernet frame. It is mainly used to create packet sockets that + // capture all traffic. + EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003 + + // EthernetProtocolPUP is the PARC Universal Packet protocol ethertype. + EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200 +) + +// Ethertypes holds the protocol numbers describing the payload of an ethernet +// frame. These types aren't necessarily supported by netstack, but can be used +// to catch all traffic of a type via packet endpoints. +var Ethertypes = []tcpip.NetworkProtocolNumber{ + EthernetProtocolAll, + EthernetProtocolPUP, +} + +// SourceAddress returns the "MAC source" field of the ethernet frame header. +func (b Ethernet) SourceAddress() tcpip.LinkAddress { + return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize]) +} + +// DestinationAddress returns the "MAC destination" field of the ethernet frame +// header. +func (b Ethernet) DestinationAddress() tcpip.LinkAddress { + return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize]) +} + +// Type returns the "ethertype" field of the ethernet frame header. +func (b Ethernet) Type() tcpip.NetworkProtocolNumber { + return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:])) +} + +// Encode encodes all the fields of the ethernet frame header. +func (b Ethernet) Encode(e *EthernetFields) { + binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type)) + copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr) + copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) +} + +// IsMulticastEthernetAddress returns true if the address is a multicast +// ethernet address. +func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool { + if len(addr) != EthernetAddressSize { + return false + } + + return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 +} + +// IsValidUnicastEthernetAddress returns true if the address is a unicast +// ethernet address. +func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { + if len(addr) != EthernetAddressSize { + return false + } + + if addr == UnspecifiedEthernetAddress { + return false + } + + if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 { + return false + } + + return true +} + +// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address +// for a multicast IPv4 address. +// +// addr MUST be a multicast IPv4 address. +func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress { + var linkAddrBytes [EthernetAddressSize]byte + // RFC 1112 Host Extensions for IP Multicasting + // + // 6.4. Extensions to an Ethernet Local Network Module: + // + // An IP host group address is mapped to an Ethernet multicast + // address by placing the low-order 23-bits of the IP address + // into the low-order 23 bits of the Ethernet multicast address + // 01-00-5E-00-00-00 (hex). + addrBytes := addr.As4() + linkAddrBytes[0] = 0x1 + linkAddrBytes[2] = 0x5e + linkAddrBytes[3] = addrBytes[1] & 0x7F + copy(linkAddrBytes[4:], addrBytes[IPv4AddressSize-2:]) + return tcpip.LinkAddress(linkAddrBytes[:]) +} + +// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address +// for a multicast IPv6 address. +// +// addr MUST be a multicast IPv6 address. +func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress { + // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks + // + // 7. Address Mapping -- Multicast + // + // An IPv6 packet with a multicast destination address DST, + // consisting of the sixteen octets DST[1] through DST[16], is + // transmitted to the Ethernet multicast address whose first + // two octets are the value 3333 hexadecimal and whose last + // four octets are the last four octets of DST. + addrBytes := addr.As16() + linkAddrBytes := []byte(addrBytes[IPv6AddressSize-EthernetAddressSize:]) + linkAddrBytes[0] = 0x33 + linkAddrBytes[1] = 0x33 + return tcpip.LinkAddress(linkAddrBytes[:]) +} diff --git a/internal/gtcpip/header/icmpv4.go b/internal/gtcpip/header/icmpv4.go new file mode 100644 index 0000000..580101c --- /dev/null +++ b/internal/gtcpip/header/icmpv4.go @@ -0,0 +1,228 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" +) + +// ICMPv4 represents an ICMPv4 header stored in a byte array. +type ICMPv4 []byte + +const ( + // ICMPv4PayloadOffset defines the start of ICMP payload. + ICMPv4PayloadOffset = 8 + + // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. + ICMPv4MinimumSize = 8 + + // ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an + // errant packet's transport layer that an ICMP error type packet should + // attempt to send as per RFC 792 (see each type) and RFC 1122 + // section 3.2.2 which states: + // Every ICMP error message includes the Internet header and at + // least the first 8 data octets of the datagram that triggered + // the error; more than 8 octets MAY be sent; this header and data + // MUST be unchanged from the received datagram. + // + // RFC 792 shows: + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Code | Checksum | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | unused | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Internet Header + 64 bits of Original Data Datagram | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ICMPv4MinimumErrorPayloadSize = 8 + + // ICMPv4ProtocolNumber is the ICMP transport protocol number. + ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 + + // icmpv4ChecksumOffset is the offset of the checksum field + // in an ICMPv4 message. + icmpv4ChecksumOffset = 2 + + // icmpv4MTUOffset is the offset of the MTU field + // in an ICMPv4FragmentationNeeded message. + icmpv4MTUOffset = 6 + + // icmpv4IdentOffset is the offset of the ident field + // in an ICMPv4EchoRequest/Reply message. + icmpv4IdentOffset = 4 + + // icmpv4PointerOffset is the offset of the pointer field + // in an ICMPv4ParamProblem message. + icmpv4PointerOffset = 4 + + // icmpv4SequenceOffset is the offset of the sequence field + // in an ICMPv4EchoRequest/Reply message. + icmpv4SequenceOffset = 6 +) + +// ICMPv4Type is the ICMP type field described in RFC 792. +type ICMPv4Type byte + +// ICMPv4Code is the ICMP code field described in RFC 792. +type ICMPv4Code byte + +// Typical values of ICMPv4Type defined in RFC 792. +const ( + ICMPv4EchoReply ICMPv4Type = 0 + ICMPv4DstUnreachable ICMPv4Type = 3 + ICMPv4SrcQuench ICMPv4Type = 4 + ICMPv4Redirect ICMPv4Type = 5 + ICMPv4Echo ICMPv4Type = 8 + ICMPv4TimeExceeded ICMPv4Type = 11 + ICMPv4ParamProblem ICMPv4Type = 12 + ICMPv4Timestamp ICMPv4Type = 13 + ICMPv4TimestampReply ICMPv4Type = 14 + ICMPv4InfoRequest ICMPv4Type = 15 + ICMPv4InfoReply ICMPv4Type = 16 +) + +// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. +const ( + ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4ReassemblyTimeout ICMPv4Code = 1 +) + +// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792, +// RFC 1122 section 3.2.2.1 and RFC 1812 section 5.2.7.1. +const ( + ICMPv4NetUnreachable ICMPv4Code = 0 + ICMPv4HostUnreachable ICMPv4Code = 1 + ICMPv4ProtoUnreachable ICMPv4Code = 2 + ICMPv4PortUnreachable ICMPv4Code = 3 + ICMPv4FragmentationNeeded ICMPv4Code = 4 + ICMPv4SourceRouteFailed ICMPv4Code = 5 + ICMPv4DestinationNetworkUnknown ICMPv4Code = 6 + ICMPv4DestinationHostUnknown ICMPv4Code = 7 + ICMPv4SourceHostIsolated ICMPv4Code = 8 + ICMPv4NetProhibited ICMPv4Code = 9 + ICMPv4HostProhibited ICMPv4Code = 10 + ICMPv4NetUnreachableForTos ICMPv4Code = 11 + ICMPv4HostUnreachableForTos ICMPv4Code = 12 + ICMPv4AdminProhibited ICMPv4Code = 13 + ICMPv4HostPrecedenceViolation ICMPv4Code = 14 + ICMPv4PrecedenceCutInEffect ICMPv4Code = 15 +) + +// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed. +const ICMPv4UnusedCode ICMPv4Code = 0 + +// Type is the ICMP type field. +func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } + +// SetType sets the ICMP type field. +func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) } + +// Code is the ICMP code field. Its meaning depends on the value of Type. +func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } + +// SetCode sets the ICMP code field. +func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } + +// Pointer returns the pointer field in a Parameter Problem packet. +func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] } + +// SetPointer sets the pointer field in a Parameter Problem packet. +func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } + +// Checksum is the ICMP checksum field. +func (b ICMPv4) Checksum() uint16 { + return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) +} + +// SetChecksum sets the ICMP checksum field. +func (b ICMPv4) SetChecksum(cs uint16) { + checksum.Put(b[icmpv4ChecksumOffset:], cs) +} + +// SourcePort implements Transport.SourcePort. +func (ICMPv4) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (ICMPv4) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (ICMPv4) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (ICMPv4) SetDestinationPort(uint16) { +} + +// Payload implements Transport.Payload. +func (b ICMPv4) Payload() []byte { + return b[ICMPv4PayloadOffset:] +} + +// MTU retrieves the MTU field from an ICMPv4 message. +func (b ICMPv4) MTU() uint16 { + return binary.BigEndian.Uint16(b[icmpv4MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv4 message. +func (b ICMPv4) SetMTU(mtu uint16) { + binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv4 message. +func (b ICMPv4) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv4IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv4 message. +func (b ICMPv4) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident) +} + +// SetIdentWithChecksumUpdate sets the Ident field and updates the checksum. +func (b ICMPv4) SetIdentWithChecksumUpdate(new uint16) { + old := b.Ident() + b.SetIdent(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// Sequence retrieves the Sequence field from an ICMPv4 message. +func (b ICMPv4) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv4 message. +func (b ICMPv4) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence) +} + +// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, +// and payload. +func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 { + xsum := payloadCsum + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = checksum.Checksum(h[:2], xsum) + xsum = checksum.Checksum(h[4:], xsum) + + return ^xsum +} diff --git a/internal/gtcpip/header/icmpv6.go b/internal/gtcpip/header/icmpv6.go new file mode 100644 index 0000000..970f743 --- /dev/null +++ b/internal/gtcpip/header/icmpv6.go @@ -0,0 +1,304 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" +) + +// ICMPv6 represents an ICMPv6 header stored in a byte array. +type ICMPv6 []byte + +const ( + // ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the + // sum of the size of the ICMPv6 Type, Code and Checksum fields, as + // per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6 + // message body begins. + ICMPv6HeaderSize = 4 + + // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. + ICMPv6MinimumSize = 8 + + // ICMPv6PayloadOffset is the offset of the payload in an + // ICMP packet. + ICMPv6PayloadOffset = 8 + + // ICMPv6ProtocolNumber is the ICMP transport protocol number. + ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58 + + // ICMPv6NeighborSolicitMinimumSize is the minimum size of a + // neighbor solicitation packet. + ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize + + // ICMPv6NeighborAdvertMinimumSize is the minimum size of a + // neighbor advertisement packet. + ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize + + // ICMPv6EchoMinimumSize is the minimum size of a valid echo packet. + ICMPv6EchoMinimumSize = 8 + + // ICMPv6ErrorHeaderSize is the size of an ICMP error packet header, + // as per RFC 4443, Appendix A, item 4 and the errata. + // ... all ICMP error messages shall have exactly + // 32 bits of type-specific data, so that receivers can reliably find + // the embedded invoking packet even when they don't recognize the + // ICMP message Type. + ICMPv6ErrorHeaderSize = 8 + + // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP + // destination unreachable packet. + ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + + // ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP + // packet-too-big packet. + ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + + // ICMPv6ChecksumOffset is the offset of the checksum field + // in an ICMPv6 message. + ICMPv6ChecksumOffset = 2 + + // icmpv6PointerOffset is the offset of the pointer + // in an ICMPv6 Parameter problem message. + icmpv6PointerOffset = 4 + + // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6 + // PacketTooBig message. + icmpv6MTUOffset = 4 + + // icmpv6IdentOffset is the offset of the ident field + // in a ICMPv6 Echo Request/Reply message. + icmpv6IdentOffset = 4 + + // icmpv6SequenceOffset is the offset of the sequence field + // in a ICMPv6 Echo Request/Reply message. + icmpv6SequenceOffset = 6 + + // NDPHopLimit is the expected IP hop limit value of 255 for received + // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, + // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently + // drop the NDP packet. All outgoing NDP packets must use this value for + // its IP hop limit field. + NDPHopLimit = 255 +) + +// ICMPv6Type is the ICMP type field described in RFC 4443. +type ICMPv6Type byte + +// Values for use in the Type field of ICMPv6 packet from RFC 4433. +const ( + ICMPv6DstUnreachable ICMPv6Type = 1 + ICMPv6PacketTooBig ICMPv6Type = 2 + ICMPv6TimeExceeded ICMPv6Type = 3 + ICMPv6ParamProblem ICMPv6Type = 4 + ICMPv6EchoRequest ICMPv6Type = 128 + ICMPv6EchoReply ICMPv6Type = 129 + + // Neighbor Discovery Protocol (NDP) messages, see RFC 4861. + + ICMPv6RouterSolicit ICMPv6Type = 133 + ICMPv6RouterAdvert ICMPv6Type = 134 + ICMPv6NeighborSolicit ICMPv6Type = 135 + ICMPv6NeighborAdvert ICMPv6Type = 136 + ICMPv6RedirectMsg ICMPv6Type = 137 + + // Multicast Listener Discovery (MLD) messages, see RFC 2710. + + ICMPv6MulticastListenerQuery ICMPv6Type = 130 + ICMPv6MulticastListenerReport ICMPv6Type = 131 + ICMPv6MulticastListenerDone ICMPv6Type = 132 + + // Multicast Listener Discovert Version 2 (MLDv2) messages, see RFC 3810. + + ICMPv6MulticastListenerV2Report ICMPv6Type = 143 +) + +// IsErrorType returns true if the receiver is an ICMP error type. +func (typ ICMPv6Type) IsErrorType() bool { + // Per RFC 4443 section 2.1: + // ICMPv6 messages are grouped into two classes: error messages and + // informational messages. Error messages are identified as such by a + // zero in the high-order bit of their message Type field values. Thus, + // error messages have message types from 0 to 127; informational + // messages have message types from 128 to 255. + return typ&0x80 == 0 +} + +// ICMPv6Code is the ICMP Code field described in RFC 4443. +type ICMPv6Code byte + +// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443 +// section 3.1. +const ( + ICMPv6NetworkUnreachable ICMPv6Code = 0 + ICMPv6Prohibited ICMPv6Code = 1 + ICMPv6BeyondScope ICMPv6Code = 2 + ICMPv6AddressUnreachable ICMPv6Code = 3 + ICMPv6PortUnreachable ICMPv6Code = 4 + ICMPv6Policy ICMPv6Code = 5 + ICMPv6RejectRoute ICMPv6Code = 6 +) + +// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3. +const ( + ICMPv6HopLimitExceeded ICMPv6Code = 0 + ICMPv6ReassemblyTimeout ICMPv6Code = 1 +) + +// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4. +const ( + // ICMPv6ErroneousHeader indicates an erroneous header field was encountered. + ICMPv6ErroneousHeader ICMPv6Code = 0 + + // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered. + ICMPv6UnknownHeader ICMPv6Code = 1 + + // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered. + ICMPv6UnknownOption ICMPv6Code = 2 +) + +// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use +// the code field. (Types not mentioned above.) +const ICMPv6UnusedCode ICMPv6Code = 0 + +// Type is the ICMP type field. +func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) } + +// SetType sets the ICMP type field. +func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) } + +// Code is the ICMP code field. Its meaning depends on the value of Type. +func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } + +// SetCode sets the ICMP code field. +func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } + +// TypeSpecific returns the type specific data field. +func (b ICMPv6) TypeSpecific() uint32 { + return binary.BigEndian.Uint32(b[icmpv6PointerOffset:]) +} + +// SetTypeSpecific sets the type specific data field. +func (b ICMPv6) SetTypeSpecific(val uint32) { + binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) +} + +// Checksum is the ICMP checksum field. +func (b ICMPv6) Checksum() uint16 { + return binary.BigEndian.Uint16(b[ICMPv6ChecksumOffset:]) +} + +// SetChecksum sets the ICMP checksum field. +func (b ICMPv6) SetChecksum(cs uint16) { + checksum.Put(b[ICMPv6ChecksumOffset:], cs) +} + +// SourcePort implements Transport.SourcePort. +func (ICMPv6) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (ICMPv6) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (ICMPv6) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (ICMPv6) SetDestinationPort(uint16) { +} + +// MTU retrieves the MTU field from an ICMPv6 message. +func (b ICMPv6) MTU() uint32 { + return binary.BigEndian.Uint32(b[icmpv6MTUOffset:]) +} + +// SetMTU sets the MTU field from an ICMPv6 message. +func (b ICMPv6) SetMTU(mtu uint32) { + binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu) +} + +// Ident retrieves the Ident field from an ICMPv6 message. +func (b ICMPv6) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv6IdentOffset:]) +} + +// SetIdent sets the Ident field from an ICMPv6 message. +func (b ICMPv6) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident) +} + +// SetIdentWithChecksumUpdate sets the Ident field and updates the checksum. +func (b ICMPv6) SetIdentWithChecksumUpdate(new uint16) { + old := b.Ident() + b.SetIdent(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// Sequence retrieves the Sequence field from an ICMPv6 message. +func (b ICMPv6) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:]) +} + +// SetSequence sets the Sequence field from an ICMPv6 message. +func (b ICMPv6) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) +} + +// MessageBody returns the message body as defined by RFC 4443 section 2.1; the +// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes. +func (b ICMPv6) MessageBody() []byte { + return b[ICMPv6HeaderSize:] +} + +// Payload implements Transport.Payload. +func (b ICMPv6) Payload() []byte { + return b[ICMPv6PayloadOffset:] +} + +// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum. +type ICMPv6ChecksumParams struct { + Header ICMPv6 + Src tcpip.Address + Dst tcpip.Address + PayloadCsum uint16 + PayloadLen int +} + +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, +// IPv6 src/dst addresses and the payload. +func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 { + h := params.Header + + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src.AsSlice(), params.Dst.AsSlice(), uint16(len(h)+params.PayloadLen)) + xsum = checksum.Combine(xsum, params.PayloadCsum) + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = checksum.Checksum(h[:2], xsum) + xsum = checksum.Checksum(h[4:], xsum) + + return ^xsum +} + +// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an +// updated address in the pseudo header. +func (b ICMPv6) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), old, new)) +} diff --git a/internal/gtcpip/header/ipv4.go b/internal/gtcpip/header/ipv4.go new file mode 100644 index 0000000..d76db68 --- /dev/null +++ b/internal/gtcpip/header/ipv4.go @@ -0,0 +1,1205 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + "net/netip" + "time" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing/common" +) + +// RFC 971 defines the fields of the IPv4 header on page 11 using the following +// diagram: ("Figure 4") +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Identification |Flags| Fragment Offset | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Time to Live | Protocol | Header Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Source Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Destination Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options | Padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +const ( + versIHL = 0 + tos = 1 + // IPv4TotalLenOffset is the offset of the total length field in the + // IPv4 header. + IPv4TotalLenOffset = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + xsum = 10 + srcAddr = 12 + dstAddr = 16 + options = 20 +) + +// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the +// fields of a packet that needs to be encoded. The IHL field is not here as +// it is totally defined by the size of the options. +type IPv4Fields struct { + // TOS is the "type of service" field of an IPv4 packet. + TOS uint8 + + // TotalLength is the "total length" field of an IPv4 packet. + TotalLength uint16 + + // ID is the "identification" field of an IPv4 packet. + ID uint16 + + // Flags is the "flags" field of an IPv4 packet. + Flags uint8 + + // FragmentOffset is the "fragment offset" field of an IPv4 packet. + FragmentOffset uint16 + + // TTL is the "time to live" field of an IPv4 packet. + TTL uint8 + + // Protocol is the "protocol" field of an IPv4 packet. + Protocol uint8 + + // Checksum is the "checksum" field of an IPv4 packet. + Checksum uint16 + + // SrcAddr is the "source ip address" of an IPv4 packet. + SrcAddr netip.Addr + + // DstAddr is the "destination ip address" of an IPv4 packet. + DstAddr netip.Addr + + // Options must be 40 bytes or less as they must fit along with the + // rest of the IPv4 header into the maximum size describable in the + // IHL field. RFC 791 section 3.1 says: + // IHL: 4 bits + // + // Internet Header Length is the length of the internet header in 32 + // bit words, and thus points to the beginning of the data. Note that + // the minimum value for a correct header is 5. + // + // That leaves ten 32 bit (4 byte) fields for options. An attempt to encode + // more will fail. + Options IPv4OptionsSerializer +} + +// IPv4 is an IPv4 header. +// Most of the methods of IPv4 access to the underlying slice without +// checking the boundaries and could panic because of 'index out of range'. +// Always call IsValid() to validate an instance of IPv4 before using other +// methods. +type IPv4 []byte + +const ( + // IPv4MinimumSize is the minimum size of a valid IPv4 packet; + // i.e. a packet header with no options. + IPv4MinimumSize = 20 + + // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given + // that there are only 4 bits (max 0xF (15)) to represent the header length + // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes. + IPv4MaximumHeaderSize = 60 + + // IPv4MaximumOptionsSize is the largest size the IPv4 options can be. + IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize + + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. + // + // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 + // header size). But RFC 791 section 3.2 discusses the design of the IPv4 + // fragment "allows 2**13 = 8192 fragments of 8 octets each for a total of + // 65,536 octets. Note that this is consistent with the datagram total + // length field (of course, the header is counted in the total length and not + // in the fragments)." + IPv4MaximumPayloadSize = 65536 + + // MinIPFragmentPayloadSize is the minimum number of payload bytes that + // the first fragment must carry when an IPv4 packet is fragmented. + MinIPFragmentPayloadSize = 8 + + // IPv4AddressSize is the size, in bytes, of an IPv4 address. + IPv4AddressSize = 4 + + // IPv4AddressSizeBits is the size, in bits, of an IPv4 address. + IPv4AddressSizeBits = 32 + + // IPv4ProtocolNumber is IPv4's network protocol number. + IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800 + + // IPv4Version is the version of the IPv4 protocol. + IPv4Version = 4 + + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP + // packet that every IPv4 capable host must be able to + // process/reassemble. + IPv4MinimumProcessableDatagramSize = 576 + + // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791, + // section 3.2: + // Every internet module must be able to forward a datagram of 68 octets + // without further fragmentation. This is because an internet header may be + // up to 60 octets, and the minimum fragment is 8 octets. + IPv4MinimumMTU = 68 +) + +var ( + // IPv4AllSystems is the all systems IPv4 multicast address as per + // IANA's IPv4 Multicast Address Space Registry. See + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml. + IPv4AllSystems = tcpip.AddrFrom4([4]byte{0xe0, 0x00, 0x00, 0x01}) + + // IPv4Broadcast is the broadcast address of the IPv4 procotol. + IPv4Broadcast = tcpip.AddrFrom4([4]byte{0xff, 0xff, 0xff, 0xff}) + + // IPv4Any is the non-routable IPv4 "any" meta address. + IPv4Any = tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}) + + // IPv4AllRoutersGroup is a multicast address for all routers. + IPv4AllRoutersGroup = tcpip.AddrFrom4([4]byte{0xe0, 0x00, 0x00, 0x02}) + + // IPv4Loopback is the loopback IPv4 address. + IPv4Loopback = tcpip.AddrFrom4([4]byte{0x7f, 0x00, 0x00, 0x01}) +) + +// Flags that may be set in an IPv4 packet. +const ( + IPv4FlagMoreFragments = 1 << iota + IPv4FlagDontFragment +) + +// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined +// by RFC 3927 section 1. +var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0xa9, 0xfe, 0x00, 0x00}), tcpip.MaskFrom("\xff\xff\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + +// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as +// defined by RFC 5771 section 4. +var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0xe0, 0x00, 0x00, 0x00}), tcpip.MaskFrom("\xff\xff\xff\x00")) + if err != nil { + panic(err) + } + return subnet +}() + +// IPv4EmptySubnet is the empty IPv4 subnet. +var IPv4EmptySubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.MaskFrom("\x00\x00\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + +// IPv4CurrentNetworkSubnet is the subnet of addresses for the current network, +// per RFC 6890 section 2.2.2, +// +// +----------------------+----------------------------+ +// | Attribute | Value | +// +----------------------+----------------------------+ +// | Address Block | 0.0.0.0/8 | +// | Name | "This host on this network"| +// | RFC | [RFC1122], Section 3.2.1.3 | +// | Allocation Date | September 1981 | +// | Termination Date | N/A | +// | Source | True | +// | Destination | False | +// | Forwardable | False | +// | Global | False | +// | Reserved-by-Protocol | True | +// +----------------------+----------------------------+ +var IPv4CurrentNetworkSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.MaskFrom("\xff\x00\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + +// IPv4LoopbackSubnet is the loopback subnet for IPv4. +var IPv4LoopbackSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x7f, 0x00, 0x00, 0x00}), tcpip.MaskFrom("\xff\x00\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + +// IPVersion returns the version of IP used in the given packet. It returns -1 +// if the packet is not large enough to contain the version field. +func IPVersion(b []byte) int { + // Length must be at least offset+length of version field. + if len(b) < versIHL+1 { + return -1 + } + return int(b[versIHL] >> ipVersionShift) +} + +// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits +// of the first byte, and is counted in multiples of 4 bytes. +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// (...) +// Version: 4 bits +// The Version field indicates the format of the internet header. This +// document describes version 4. +// +// IHL: 4 bits +// Internet Header Length is the length of the internet header in 32 +// bit words, and thus points to the beginning of the data. Note that +// the minimum value for a correct header is 5. +const ( + ipVersionShift = 4 + ipIHLMask = 0x0f + IPv4IHLStride = 4 +) + +// HeaderLength returns the value of the "header length" field of the IPv4 +// header. The length returned is in bytes. +func (b IPv4) HeaderLength() uint8 { + return (b[versIHL] & ipIHLMask) * IPv4IHLStride +} + +// SetHeaderLength sets the value of the "Internet Header Length" field. +func (b IPv4) SetHeaderLength(hdrLen uint8) { + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize)) + } + b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) +} + +// ID returns the value of the identifier field of the IPv4 header. +func (b IPv4) ID() uint16 { + return binary.BigEndian.Uint16(b[id:]) +} + +// Protocol returns the value of the protocol field of the IPv4 header. +func (b IPv4) Protocol() uint8 { + return b[protocol] +} + +// Flags returns the "flags" field of the IPv4 header. +func (b IPv4) Flags() uint8 { + return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) +} + +// More returns whether the more fragments flag is set. +func (b IPv4) More() bool { + return b.Flags()&IPv4FlagMoreFragments != 0 +} + +// TTL returns the "TTL" field of the IPv4 header. +func (b IPv4) TTL() uint8 { + return b[ttl] +} + +// FragmentOffset returns the "fragment offset" field of the IPv4 header. +func (b IPv4) FragmentOffset() uint16 { + return binary.BigEndian.Uint16(b[flagsFO:]) << 3 +} + +// TotalLength returns the "total length" field of the IPv4 header. +func (b IPv4) TotalLength() uint16 { + return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) +} + +// Checksum returns the checksum field of the IPv4 header. +func (b IPv4) Checksum() uint16 { + return binary.BigEndian.Uint16(b[xsum:]) +} + +// SourceAddress returns the "source address" field of the IPv4 header. +func (b IPv4) SourceAddress() tcpip.Address { + return tcpip.AddrFrom4([4]byte(b[srcAddr : srcAddr+IPv4AddressSize])) +} + +// DestinationAddress returns the "destination address" field of the IPv4 +// header. +func (b IPv4) DestinationAddress() tcpip.Address { + return tcpip.AddrFrom4([4]byte(b[dstAddr : dstAddr+IPv4AddressSize])) +} + +// SourceAddressSlice returns the "source address" field of the IPv4 header as a +// byte slice. +func (b IPv4) SourceAddressSlice() []byte { + return []byte(b[srcAddr : srcAddr+IPv4AddressSize]) +} + +// DestinationAddressSlice returns the "destination address" field of the IPv4 +// header as a byte slice. +func (b IPv4) DestinationAddressSlice() []byte { + return []byte(b[dstAddr : dstAddr+IPv4AddressSize]) +} + +// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new)) + b.SetSourceAddress(new) +} + +// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new)) + b.SetDestinationAddress(new) +} + +// padIPv4OptionsLength returns the total length for IPv4 options of length l +// after applying padding according to RFC 791: +// +// The internet header padding is used to ensure that the internet +// header ends on a 32 bit boundary. +func padIPv4OptionsLength(length uint8) uint8 { + return (length + IPv4IHLStride - 1) & ^uint8(IPv4IHLStride-1) +} + +// IPv4Options is a buffer that holds all the raw IP options. +type IPv4Options []byte + +// Options returns a buffer holding the options. +func (b IPv4) Options() IPv4Options { + hdrLen := b.HeaderLength() + return IPv4Options(b[options:hdrLen:hdrLen]) +} + +// TransportProtocol implements Network.TransportProtocol. +func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber { + return tcpip.TransportProtocolNumber(b.Protocol()) +} + +// Payload implements Network.Payload. +func (b IPv4) Payload() []byte { + return b[b.HeaderLength():][:b.PayloadLength()] +} + +// PayloadLength returns the length of the payload portion of the IPv4 packet. +func (b IPv4) PayloadLength() uint16 { + return b.TotalLength() - uint16(b.HeaderLength()) +} + +// TOS returns the "type of service" field of the IPv4 header. +func (b IPv4) TOS() (uint8, uint32) { + return b[tos], 0 +} + +// SetTOS sets the "type of service" field of the IPv4 header. +func (b IPv4) SetTOS(v uint8, _ uint32) { + b[tos] = v +} + +// SetTTL sets the "Time to Live" field of the IPv4 header. +func (b IPv4) SetTTL(v byte) { + b[ttl] = v +} + +// SetTotalLength sets the "total length" field of the IPv4 header. +func (b IPv4) SetTotalLength(totalLength uint16) { + binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) +} + +// SetChecksum sets the checksum field of the IPv4 header. +func (b IPv4) SetChecksum(v uint16) { + checksum.Put(b[xsum:], v) +} + +// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the +// IPv4 header. +func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { + v := (uint16(flags) << 13) | (offset >> 3) + binary.BigEndian.PutUint16(b[flagsFO:], v) +} + +// SetID sets the identification field. +func (b IPv4) SetID(v uint16) { + binary.BigEndian.PutUint16(b[id:], v) +} + +// SetSourceAddress sets the "source address" field of the IPv4 header. +func (b IPv4) SetSourceAddress(addr tcpip.Address) { + copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.AsSlice()) +} + +// SetDestinationAddress sets the "destination address" field of the IPv4 +// header. +func (b IPv4) SetDestinationAddress(addr tcpip.Address) { + copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.AsSlice()) +} + +// CalculateChecksum calculates the checksum of the IPv4 header. +func (b IPv4) CalculateChecksum() uint16 { + return checksum.Checksum(b[:b.HeaderLength()], 0) +} + +// Encode encodes all the fields of the IPv4 header. +func (b IPv4) Encode(i *IPv4Fields) { + // The size of the options defines the size of the whole header and thus the + // IHL field. Options are rare and this is a heavily used function so it is + // worth a bit of optimisation here to keep the serializer out of the fast + // path. + hdrLen := uint8(IPv4MinimumSize) + if len(i.Options) != 0 { + hdrLen += i.Options.Serialize(b[options:]) + } + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("%d is larger than maximum IPv4 header size of %d", hdrLen, IPv4MaximumHeaderSize)) + } + b.SetHeaderLength(hdrLen) + b[tos] = i.TOS + b.SetTotalLength(i.TotalLength) + binary.BigEndian.PutUint16(b[id:], i.ID) + b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset) + b[ttl] = i.TTL + b[protocol] = i.Protocol + b.SetChecksum(i.Checksum) + copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr.AsSlice()) + copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr.AsSlice()) +} + +// EncodePartial updates the total length and checksum fields of IPv4 header, +// taking in the partial checksum, which is the checksum of the header without +// the total length and checksum fields. It is useful in cases when similar +// packets are produced. +func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) { + b.SetTotalLength(totalLength) + xsum := checksum.Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum) + b.SetChecksum(^xsum) +} + +// IsValid performs basic validation on the packet. +func (b IPv4) IsValid(pktSize int) bool { + if len(b) < IPv4MinimumSize { + return false + } + + hlen := int(b.HeaderLength()) + tlen := int(b.TotalLength()) + if hlen < IPv4MinimumSize || hlen > tlen || tlen > pktSize { + return false + } + + if IPVersion(b) != IPv4Version { + return false + } + + return true +} + +// IsV4LinkLocalUnicastAddress determines if the provided address is an IPv4 +// link-local unicast address. +func IsV4LinkLocalUnicastAddress(addr tcpip.Address) bool { + return ipv4LinkLocalUnicastSubnet.Contains(addr) +} + +// IsV4LinkLocalMulticastAddress determines if the provided address is an IPv4 +// link-local multicast address. +func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool { + return ipv4LinkLocalMulticastSubnet.Contains(addr) +} + +// IsChecksumValid returns true iff the IPv4 header's checksum is valid. +func (b IPv4) IsChecksumValid() bool { + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + return b.CalculateChecksum() == 0xffff +} + +// IsV4MulticastAddress determines if the provided address is an IPv4 multicast +// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits +// will be 1110 = 0xe0. +func IsV4MulticastAddress(addr tcpip.Address) bool { + if addr.BitLen() != IPv4AddressSizeBits { + return false + } + addrBytes := addr.As4() + return (addrBytes[0] & 0xf0) == 0xe0 +} + +// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback +// address (belongs to 127.0.0.0/8 subnet). See RFC 1122 section 3.2.1.3. +func IsV4LoopbackAddress(addr tcpip.Address) bool { + if addr.BitLen() != IPv4AddressSizeBits { + return false + } + addrBytes := addr.As4() + return addrBytes[0] == 0x7f +} + +// ========================= Options ========================== + +// An IPv4OptionType can hold the value for the Type in an IPv4 option. +type IPv4OptionType byte + +// These constants are needed to identify individual options in the option list. +// While RFC 791 (page 31) says "Every internet module must be able to act on +// every option." This has not generally been adhered to and some options have +// very low rates of support. We do not support options other than those shown +// below. + +const ( + // IPv4OptionListEndType is the option type for the End Of Option List + // option. Anything following is ignored. + IPv4OptionListEndType IPv4OptionType = 0 + + // IPv4OptionNOPType is the No-Operation option. May appear between other + // options and may appear multiple times. + IPv4OptionNOPType IPv4OptionType = 1 + + // IPv4OptionRouterAlertType is the option type for the Router Alert option, + // defined in RFC 2113 Section 2.1. + IPv4OptionRouterAlertType IPv4OptionType = 20 | 0x80 + + // IPv4OptionRecordRouteType is used by each router on the path of the packet + // to record its path. It is carried over to an Echo Reply. + IPv4OptionRecordRouteType IPv4OptionType = 7 + + // IPv4OptionTimestampType is the option type for the Timestamp option. + IPv4OptionTimestampType IPv4OptionType = 68 + + // ipv4OptionTypeOffset is the offset in an option of its type field. + ipv4OptionTypeOffset = 0 + + // IPv4OptionLengthOffset is the offset in an option of its length field. + IPv4OptionLengthOffset = 1 +) + +// IPv4OptParameterProblem indicates that a Parameter Problem message +// should be generated, and gives the offset in the current entity +// that should be used in that packet. +type IPv4OptParameterProblem struct { + Pointer uint8 + NeedICMP bool +} + +// IPv4Option is an interface representing various option types. +type IPv4Option interface { + // Type returns the type identifier of the option. + Type() IPv4OptionType + + // Size returns the size of the option in bytes. + Size() uint8 + + // Contents returns a slice holding the contents of the option. + Contents() []byte +} + +var _ IPv4Option = (*IPv4OptionGeneric)(nil) + +// IPv4OptionGeneric is an IPv4 Option of unknown type. +type IPv4OptionGeneric []byte + +// Type implements IPv4Option. +func (o *IPv4OptionGeneric) Type() IPv4OptionType { + return IPv4OptionType((*o)[ipv4OptionTypeOffset]) +} + +// Size implements IPv4Option. +func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } + +// Contents implements IPv4Option. +func (o *IPv4OptionGeneric) Contents() []byte { return *o } + +// IPv4OptionIterator is an iterator pointing to a specific IP option +// at any point of time. It also holds information as to a new options buffer +// that we are building up to hand back to the caller. +// TODO(https://gvisor.dev/issues/5513): Add unit tests for IPv4OptionIterator. +type IPv4OptionIterator struct { + options IPv4Options + // ErrCursor is where we are while parsing options. It is exported as any + // resulting ICMP packet is supposed to have a pointer to the byte within + // the IP packet where the error was detected. + ErrCursor uint8 + nextErrCursor uint8 + newOptions [IPv4MaximumOptionsSize]byte + writePoint int +} + +// MakeIterator sets up and returns an iterator of options. It also sets up the +// building of a new option set. +func (o IPv4Options) MakeIterator() IPv4OptionIterator { + return IPv4OptionIterator{ + options: o, + nextErrCursor: IPv4MinimumSize, + } +} + +// InitReplacement copies the option into the new option buffer. +func (i *IPv4OptionIterator) InitReplacement(option IPv4Option) IPv4Options { + replacementOption := i.RemainingBuffer()[:option.Size()] + if copied := copy(replacementOption, option.Contents()); copied != len(replacementOption) { + panic(fmt.Sprintf("copied %d bytes in the replacement option buffer, expected %d bytes", copied, len(replacementOption))) + } + return replacementOption +} + +// RemainingBuffer returns the remaining (unused) part of the new option buffer, +// into which a new option may be written. +func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { + return i.newOptions[i.writePoint:] +} + +// ConsumeBuffer marks a portion of the new buffer as used. +func (i *IPv4OptionIterator) ConsumeBuffer(size int) { + i.writePoint += size +} + +// PushNOPOrEnd puts one of the single byte options onto the new options. +// Only values 0 or 1 (ListEnd or NOP) are valid input. +func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) { + if val > IPv4OptionNOPType { + panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val)) + } + i.newOptions[i.writePoint] = byte(val) + i.writePoint++ +} + +// Finalize returns the completed replacement options buffer padded +// as needed. +func (i *IPv4OptionIterator) Finalize() IPv4Options { + // RFC 791 page 31 says: + // The options might not end on a 32-bit boundary. The internet header + // must be filled out with octets of zeros. The first of these would + // be interpreted as the end-of-options option, and the remainder as + // internet header padding. + // Since the buffer is already zero filled we just need to step the write + // pointer up to the next multiple of 4. + options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3]) + // Poison the write pointer. + i.writePoint = len(i.newOptions) + return options +} + +// Next returns the next IP option in the buffer/list of IP options. +// It returns +// - A slice of bytes holding the next option or nil if there is error. +// - A boolean which is true if parsing of all the options is complete. +// Undefined in the case of error. +// - An error indication which is non-nil if an error condition was found. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, *IPv4OptParameterProblem) { + // The opts slice gets shorter as we process the options. When we have no + // bytes left we are done. + if len(i.options) == 0 { + return nil, true, nil + } + + i.ErrCursor = i.nextErrCursor + + optType := IPv4OptionType(i.options[ipv4OptionTypeOffset]) + + if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType { + optionBody := i.options[:1] + i.options = i.options[1:] + i.nextErrCursor = i.ErrCursor + 1 + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil + } + + // There are no more single byte options defined. All the rest have a length + // field so we need to sanity check it. + if len(i.options) == 1 { + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } + } + + optLen := i.options[IPv4OptionLengthOffset] + + if optLen <= IPv4OptionLengthOffset || optLen > uint8(len(i.options)) { + // The actual error is in the length (2nd byte of the option) but we + // return the start of the option for compatibility with Linux. + + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } + } + + optionBody := i.options[:optLen] + i.nextErrCursor = i.ErrCursor + optLen + i.options = i.options[optLen:] + + // Check the length of some option types that we know. + switch optType { + case IPv4OptionTimestampType: + if optLen < IPv4OptionTimestampHdrLength { + i.ErrCursor++ + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } + } + retval := IPv4OptionTimestamp(optionBody) + return &retval, false, nil + + case IPv4OptionRecordRouteType: + if optLen < IPv4OptionRecordRouteHdrLength { + i.ErrCursor++ + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } + } + retval := IPv4OptionRecordRoute(optionBody) + return &retval, false, nil + + case IPv4OptionRouterAlertType: + if optLen != IPv4OptionRouterAlertLength { + i.ErrCursor++ + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } + } + retval := IPv4OptionRouterAlert(optionBody) + return &retval, false, nil + } + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil +} + +// +// IP Timestamp option - RFC 791 page 22. +// +--------+--------+--------+--------+ +// |01000100| length | pointer|oflw|flg| +// +--------+--------+--------+--------+ +// | internet address | +// +--------+--------+--------+--------+ +// | timestamp | +// +--------+--------+--------+--------+ +// | ... | +// +// Type = 68 +// +// The Option Length is the number of octets in the option counting +// the type, length, pointer, and overflow/flag octets (maximum +// length 40). +// +// The Pointer is the number of octets from the beginning of this +// option to the end of timestamps plus one (i.e., it points to the +// octet beginning the space for next timestamp). The smallest +// legal value is 5. The timestamp area is full when the pointer +// is greater than the length. +// +// The Overflow (oflw) [4 bits] is the number of IP modules that +// cannot register timestamps due to lack of space. +// +// The Flag (flg) [4 bits] values are +// +// 0 -- time stamps only, stored in consecutive 32-bit words, +// +// 1 -- each timestamp is preceded with internet address of the +// registering entity, +// +// 3 -- the internet address fields are prespecified. An IP +// module only registers its timestamp if it matches its own +// address with the next specified internet address. +// +// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UTC. +// +// The Timestamp is a right-justified, 32-bit timestamp in +// milliseconds since midnight UT. If the time is not available in +// milliseconds or cannot be provided with respect to midnight UT +// then any time may be inserted as a timestamp provided the high +// order bit of the timestamp field is set to one to indicate the +// use of a non-standard value. + +// IPv4OptTSFlags sefines the values expected in the Timestamp +// option Flags field. +type IPv4OptTSFlags uint8 + +// Timestamp option specific related constants. +const ( + // IPv4OptionTimestampHdrLength is the length of the timestamp option header. + IPv4OptionTimestampHdrLength = 4 + + // IPv4OptionTimestampSize is the size of an IP timestamp. + IPv4OptionTimestampSize = 4 + + // IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address. + IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize + + // IPv4OptionTimestampMaxSize is limited by space for options + IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize + + // IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp + // is present. + IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0 + + // IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and + // IP are present. + IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1 + + // IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that + // predefined IP is present. + IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3 +) + +// ipv4TimestampTime provides the current time as specified in RFC 791. +func ipv4TimestampTime(clock tcpip.Clock) uint32 { + // Per RFC 791 page 21: + // The Timestamp is a right-justified, 32-bit timestamp in + // milliseconds since midnight UT. + now := clock.Now().UTC() + midnight := now.Truncate(24 * time.Hour) + return uint32(now.Sub(midnight).Milliseconds()) +} + +// IP Timestamp option fields. +const ( + // IPv4OptTSPointerOffset is the offset of the Timestamp pointer field. + IPv4OptTSPointerOffset = 2 + + // IPv4OptTSPointerOffset is the offset of the combined Flag and Overflow + // fields, (each being 4 bits). + IPv4OptTSOFLWAndFLGOffset = 3 + // These constants define the sub byte fields of the Flag and OverFlow field. + ipv4OptionTimestampOverflowshift = 4 + ipv4OptionTimestampFlagsMask byte = 0x0f +) + +var _ IPv4Option = (*IPv4OptionTimestamp)(nil) + +// IPv4OptionTimestamp is a Timestamp option from RFC 791. +type IPv4OptionTimestamp []byte + +// Type implements IPv4Option.Type(). +func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType } + +// Size implements IPv4Option. +func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) } + +// Contents implements IPv4Option. +func (ts *IPv4OptionTimestamp) Contents() []byte { return *ts } + +// Pointer returns the pointer field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Pointer() uint8 { + return (*ts)[IPv4OptTSPointerOffset] +} + +// Flags returns the flags field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags { + return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask) +} + +// Overflow returns the Overflow field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Overflow() uint8 { + return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift +} + +// IncOverflow increments the Overflow field in the IP Timestamp option. It +// returns the incremented value. If the return value is 0 then the field +// overflowed. +func (ts *IPv4OptionTimestamp) IncOverflow() uint8 { + (*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift + return ts.Overflow() +} + +// UpdateTimestamp updates the fields of the next free timestamp slot. +func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) { + slot := (*ts)[ts.Pointer()-1:] + + switch ts.Flags() { + case IPv4OptionTimestampOnlyFlag: + binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize + case IPv4OptionTimestampWithIPFlag: + if n := copy(slot, addr.AsSlice()); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + case IPv4OptionTimestampWithPredefinedIPFlag: + if tcpip.AddrFrom4([4]byte(slot[:IPv4AddressSize])) == addr { + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + } + } +} + +// RecordRoute option specific related constants. +// +// from RFC 791 page 20: +// +// Record Route +// +// +--------+--------+--------+---------//--------+ +// |00000111| length | pointer| route data | +// +--------+--------+--------+---------//--------+ +// Type=7 +// +// The record route option provides a means to record the route of +// an internet datagram. +// +// The option begins with the option type code. The second octet +// is the option length which includes the option type code and the +// length octet, the pointer octet, and length-3 octets of route +// data. The third octet is the pointer into the route data +// indicating the octet which begins the next area to store a route +// address. The pointer is relative to this option, and the +// smallest legal value for the pointer is 4. +const ( + // IPv4OptionRecordRouteHdrLength is the length of the Record Route option + // header. + IPv4OptionRecordRouteHdrLength = 3 + + // IPv4OptRRPointerOffset is the offset to the pointer field in an RR + // option, which points to the next free slot in the list of addresses. + IPv4OptRRPointerOffset = 2 +) + +var _ IPv4Option = (*IPv4OptionRecordRoute)(nil) + +// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791. +type IPv4OptionRecordRoute []byte + +// Pointer returns the pointer field in the IP RecordRoute option. +func (rr *IPv4OptionRecordRoute) Pointer() uint8 { + return (*rr)[IPv4OptRRPointerOffset] +} + +// StoreAddress stores the given IPv4 address into the next free slot. +func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) { + start := rr.Pointer() - 1 // A one based number. + // start and room checked by caller. + if n := copy((*rr)[start:], addr.AsSlice()); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize +} + +// Type implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType } + +// Size implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } + +// Contents implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Contents() []byte { return *rr } + +// Router Alert option specific related constants. +// +// from RFC 2113 section 2.1: +// +// +--------+--------+--------+--------+ +// |10010100|00000100| 2 octet value | +// +--------+--------+--------+--------+ +// +// Type: +// Copied flag: 1 (all fragments must carry the option) +// Option class: 0 (control) +// Option number: 20 (decimal) +// +// Length: 4 +// +// Value: A two octet code with the following values: +// 0 - Router shall examine packet +// 1-65535 - Reserved +const ( + // IPv4OptionRouterAlertLength is the length of a Router Alert option. + IPv4OptionRouterAlertLength = 4 + + // IPv4OptionRouterAlertValue is the only permissible value of the 16 bit + // payload of the router alert option. + IPv4OptionRouterAlertValue = 0 + + // IPv4OptionRouterAlertValueOffset is the offset for the value of a + // RouterAlert option. + IPv4OptionRouterAlertValueOffset = 2 +) + +var _ IPv4Option = (*IPv4OptionRouterAlert)(nil) + +// IPv4OptionRouterAlert is an IPv4 RouterAlert option defined by RFC 2113. +type IPv4OptionRouterAlert []byte + +// Type implements IPv4Option. +func (*IPv4OptionRouterAlert) Type() IPv4OptionType { return IPv4OptionRouterAlertType } + +// Size implements IPv4Option. +func (ra *IPv4OptionRouterAlert) Size() uint8 { return uint8(len(*ra)) } + +// Contents implements IPv4Option. +func (ra *IPv4OptionRouterAlert) Contents() []byte { return *ra } + +// Value returns the value of the IPv4OptionRouterAlert. +func (ra *IPv4OptionRouterAlert) Value() uint16 { + return binary.BigEndian.Uint16(ra.Contents()[IPv4OptionRouterAlertValueOffset:]) +} + +// IPv4SerializableOption is an interface to represent serializable IPv4 option +// types. +type IPv4SerializableOption interface { + // optionType returns the type identifier of the option. + optionType() IPv4OptionType +} + +// IPv4SerializableOptionPayload is an interface providing serialization of the +// payload of an IPv4 option. +type IPv4SerializableOptionPayload interface { + // length returns the size of the payload. + length() uint8 + + // serializeInto serializes the payload into the provided byte buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // Length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MUST panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto(buffer []byte) uint8 +} + +// IPv4OptionsSerializer is a serializer for IPv4 options. +type IPv4OptionsSerializer []IPv4SerializableOption + +// Length returns the total number of bytes required to serialize the options. +func (s IPv4OptionsSerializer) Length() uint8 { + var total uint8 + for _, opt := range s { + total++ + if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok { + // Add 1 to reported length to account for the length byte. + total += 1 + withPayload.length() + } + } + return padIPv4OptionsLength(total) +} + +// Serialize serializes the provided list of IPV4 options into b. +// +// Note, b must be of sufficient size to hold all the options in s. See +// IPv4OptionsSerializer.Length for details on the getting the total size +// of a serialized IPv4OptionsSerializer. +// +// Serialize panics if b is not of sufficient size to hold all the options in s. +func (s IPv4OptionsSerializer) Serialize(b []byte) uint8 { + var total uint8 + for _, opt := range s { + ty := opt.optionType() + if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok { + // Serialize first to reduce bounds checks. + l := 2 + withPayload.serializeInto(b[2:]) + b[0] = byte(ty) + b[1] = l + b = b[l:] + total += l + continue + } + // Options without payload consist only of the type field. + // + // NB: Repeating code from the branch above is intentional to minimize + // bounds checks. + b[0] = byte(ty) + b = b[1:] + total++ + } + + // According to RFC 791: + // + // The internet header padding is used to ensure that the internet + // header ends on a 32 bit boundary. The padding is zero. + padded := padIPv4OptionsLength(total) + b = b[:padded-total] + common.ClearArray(b) + return padded +} + +var ( + _ IPv4SerializableOptionPayload = (*IPv4SerializableRouterAlertOption)(nil) + _ IPv4SerializableOption = (*IPv4SerializableRouterAlertOption)(nil) +) + +// IPv4SerializableRouterAlertOption provides serialization of the Router Alert +// IPv4 option according to RFC 2113. +type IPv4SerializableRouterAlertOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableRouterAlertOption) optionType() IPv4OptionType { + return IPv4OptionRouterAlertType +} + +// Length implements IPv4SerializableOption. +func (*IPv4SerializableRouterAlertOption) length() uint8 { + return IPv4OptionRouterAlertLength - IPv4OptionRouterAlertValueOffset +} + +// SerializeInto implements IPv4SerializableOption. +func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 { + binary.BigEndian.PutUint16(buffer, IPv4OptionRouterAlertValue) + return o.length() +} + +var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil) + +// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option. +type IPv4SerializableNOPOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableNOPOption) optionType() IPv4OptionType { + return IPv4OptionNOPType +} + +var _ IPv4SerializableOption = (*IPv4SerializableListEndOption)(nil) + +// IPv4SerializableListEndOption provides serialization for the IPv4 List End +// option. +type IPv4SerializableListEndOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableListEndOption) optionType() IPv4OptionType { + return IPv4OptionListEndType +} diff --git a/internal/gtcpip/header/ipv6.go b/internal/gtcpip/header/ipv6.go new file mode 100644 index 0000000..1a5a7a0 --- /dev/null +++ b/internal/gtcpip/header/ipv6.go @@ -0,0 +1,578 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "crypto/sha256" + "encoding/binary" + "fmt" + "net/netip" + + "github.com/sagernet/sing-tun/internal/gtcpip" +) + +const ( + versTCFL = 0 + // IPv6PayloadLenOffset is the offset of the PayloadLength field in + // IPv6 header. + IPv6PayloadLenOffset = 4 + // IPv6NextHeaderOffset is the offset of the NextHeader field in + // IPv6 header. + IPv6NextHeaderOffset = 6 + hopLimit = 7 + v6SrcAddr = 8 + v6DstAddr = v6SrcAddr + IPv6AddressSize + + // IPv6FixedHeaderSize is the size of the fixed header. + IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize +) + +// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the +// fields of a packet that needs to be encoded. +type IPv6Fields struct { + // TrafficClass is the "traffic class" field of an IPv6 packet. + TrafficClass uint8 + + // FlowLabel is the "flow label" field of an IPv6 packet. + FlowLabel uint32 + + // PayloadLength is the "payload length" field of an IPv6 packet, including + // the length of all extension headers. + PayloadLength uint16 + + // TransportProtocol is the transport layer protocol number. Serialized in the + // last "next header" field of the IPv6 header + extension headers. + TransportProtocol tcpip.TransportProtocolNumber + + // HopLimit is the "Hop Limit" field of an IPv6 packet. + HopLimit uint8 + + // SrcAddr is the "source ip address" of an IPv6 packet. + SrcAddr netip.Addr + + // DstAddr is the "destination ip address" of an IPv6 packet. + DstAddr netip.Addr + + // ExtensionHeaders are the extension headers following the IPv6 header. + ExtensionHeaders IPv6ExtHdrSerializer +} + +// IPv6 represents an ipv6 header stored in a byte array. +// Most of the methods of IPv6 access to the underlying slice without +// checking the boundaries and could panic because of 'index out of range'. +// Always call IsValid() to validate an instance of IPv6 before using other methods. +type IPv6 []byte + +const ( + // IPv6MinimumSize is the minimum size of a valid IPv6 packet. + IPv6MinimumSize = IPv6FixedHeaderSize + + // IPv6AddressSize is the size, in bytes, of an IPv6 address. + IPv6AddressSize = 16 + + // IPv6AddressSizeBits is the size, in bits, of an IPv6 address. + IPv6AddressSizeBits = 128 + + // IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per + // RFC 8200 Section 4.5. + IPv6MaximumPayloadSize = 65535 + + // IPv6ProtocolNumber is IPv6's network protocol number. + IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd + + // IPv6Version is the version of the ipv6 protocol. + IPv6Version = 6 + + // IIDSize is the size of an interface identifier (IID), in bytes, as + // defined by RFC 4291 section 2.5.1. + IIDSize = 8 + + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, + // section 5: + // IPv6 requires that every link in the Internet have an MTU of 1280 octets + // or greater. This is known as the IPv6 minimum link MTU. + IPv6MinimumMTU = 1280 + + // IIDOffsetInIPv6Address is the offset, in bytes, from the start + // of an IPv6 address to the beginning of the interface identifier + // (IID) for auto-generated addresses. That is, all bytes before + // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all + // bytes including and after the IIDOffsetInIPv6Address-th byte are + // for the IID. + IIDOffsetInIPv6Address = 8 + + // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes + // for the secret key used to generate an opaque interface identifier as + // outlined by RFC 7217. + OpaqueIIDSecretKeyMinBytes = 16 + + // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field + // is located within a multicast IPv6 address, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeByteIdx = 1 + + // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, + // within the byte holding the field, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeMask = 0xF +) + +var ( + // IPv6AllNodesMulticastAddress is a link-local multicast group that + // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all nodes on a link. + // + // The address is ff02::1. + IPv6AllNodesMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) + + // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local + // multicast group that all IPv6 routers MUST join, as per RFC 4291, section + // 2.8. Packets destined to this address will reach the router on an + // interface. + // + // The address is ff01::2. + IPv6AllRoutersInterfaceLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}) + + // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers on a link. + // + // The address is ff02::2. + IPv6AllRoutersLinkLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}) + + // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers in a site. + // + // The address is ff05::2. + IPv6AllRoutersSiteLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}) + + // IPv6Loopback is the IPv6 Loopback address. + IPv6Loopback = tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) + + // IPv6Any is the non-routable IPv6 "any" meta address. It is also + // known as the unspecified address. + IPv6Any = tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) +) + +// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the +// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to +// be contained within this subnet. +var IPv6EmptySubnet = tcpip.AddressWithPrefix{ + Address: IPv6Any, + PrefixLen: 0, +}.Subnet() + +// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined +// by RFC 4291 section 2.5.5. +var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{ + Address: tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00}), + PrefixLen: 96, +}.Subnet() + +// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined +// by RFC 4291 section 2.5.6. +// +// The prefix is fe80::/64 +var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{ + Address: tcpip.AddrFrom16([16]byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + PrefixLen: 64, +} + +// PayloadLength returns the value of the "payload length" field of the ipv6 +// header. +func (b IPv6) PayloadLength() uint16 { + return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) +} + +// HopLimit returns the value of the "Hop Limit" field of the ipv6 header. +func (b IPv6) HopLimit() uint8 { + return b[hopLimit] +} + +// NextHeader returns the value of the "next header" field of the ipv6 header. +func (b IPv6) NextHeader() uint8 { + return b[IPv6NextHeaderOffset] +} + +// TransportProtocol implements Network.TransportProtocol. +func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber { + return tcpip.TransportProtocolNumber(b.NextHeader()) +} + +// Payload implements Network.Payload. +func (b IPv6) Payload() []byte { + return b[IPv6MinimumSize:][:b.PayloadLength()] +} + +// SourceAddress returns the "source address" field of the ipv6 header. +func (b IPv6) SourceAddress() tcpip.Address { + return tcpip.AddrFrom16([16]byte(b[v6SrcAddr:][:IPv6AddressSize])) +} + +// DestinationAddress returns the "destination address" field of the ipv6 +// header. +func (b IPv6) DestinationAddress() tcpip.Address { + return tcpip.AddrFrom16([16]byte(b[v6DstAddr:][:IPv6AddressSize])) +} + +// SourceAddressSlice returns the "source address" field of the ipv6 header as a +// byte slice. +func (b IPv6) SourceAddressSlice() []byte { + return []byte(b[v6SrcAddr:][:IPv6AddressSize]) +} + +// DestinationAddressSlice returns the "destination address" field of the ipv6 +// header as a byte slice. +func (b IPv6) DestinationAddressSlice() []byte { + return []byte(b[v6DstAddr:][:IPv6AddressSize]) +} + +// Checksum implements Network.Checksum. Given that IPv6 doesn't have a +// checksum, it just returns 0. +func (IPv6) Checksum() uint16 { + return 0 +} + +// TOS returns the "traffic class" and "flow label" fields of the ipv6 header. +func (b IPv6) TOS() (uint8, uint32) { + v := binary.BigEndian.Uint32(b[versTCFL:]) + return uint8(v >> 20), v & 0xfffff +} + +// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header. +func (b IPv6) SetTOS(t uint8, l uint32) { + vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff) + binary.BigEndian.PutUint32(b[versTCFL:], vtf) +} + +// SetPayloadLength sets the "payload length" field of the ipv6 header. +func (b IPv6) SetPayloadLength(payloadLength uint16) { + binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength) +} + +// SetSourceAddress sets the "source address" field of the ipv6 header. +func (b IPv6) SetSourceAddress(addr tcpip.Address) { + copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice()) +} + +// SetDestinationAddress sets the "destination address" field of the ipv6 +// header. +func (b IPv6) SetDestinationAddress(addr tcpip.Address) { + copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice()) +} + +// SetHopLimit sets the value of the "Hop Limit" field. +func (b IPv6) SetHopLimit(v uint8) { + b[hopLimit] = v +} + +// SetNextHeader sets the value of the "next header" field of the ipv6 header. +func (b IPv6) SetNextHeader(v uint8) { + b[IPv6NextHeaderOffset] = v +} + +// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a +// checksum, it is empty. +func (IPv6) SetChecksum(uint16) { +} + +// Encode encodes all the fields of the ipv6 header. +func (b IPv6) Encode(i *IPv6Fields) { + extHdr := b[IPv6MinimumSize:] + b.SetTOS(i.TrafficClass, i.FlowLabel) + b.SetPayloadLength(i.PayloadLength) + b[hopLimit] = i.HopLimit + b.SetSourceAddr(i.SrcAddr) + b.SetDestinationAddr(i.DstAddr) + nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr) + b[IPv6NextHeaderOffset] = nextHeader +} + +// IsValid performs basic validation on the packet. +func (b IPv6) IsValid(pktSize int) bool { + if len(b) < IPv6MinimumSize { + return false + } + + dlen := int(b.PayloadLength()) + if dlen > pktSize-IPv6MinimumSize { + return false + } + + if IPVersion(b) != IPv6Version { + return false + } + + return true +} + +// IsV4MappedAddress determines if the provided address is an IPv4 mapped +// address by checking if its prefix is 0:0:0:0:0:ffff::/96. +func IsV4MappedAddress(addr tcpip.Address) bool { + if addr.BitLen() != IPv6AddressSizeBits { + return false + } + + return IPv4MappedIPv6Subnet.Contains(addr) +} + +// IsV6MulticastAddress determines if the provided address is an IPv6 +// multicast address (anything starting with FF). +func IsV6MulticastAddress(addr tcpip.Address) bool { + if addr.BitLen() != IPv6AddressSizeBits { + return false + } + return addr.As16()[0] == 0xff +} + +// IsV6UnicastAddress determines if the provided address is a valid IPv6 +// unicast (and specified) address. That is, IsV6UnicastAddress returns +// true if addr contains IPv6AddressSize bytes, is not the unspecified +// address and is not a multicast address. +func IsV6UnicastAddress(addr tcpip.Address) bool { + if addr.BitLen() != IPv6AddressSizeBits { + return false + } + + // Must not be unspecified + if addr == IPv6Any { + return false + } + + // Return if not a multicast. + return addr.As16()[0] != 0xff +} + +var solicitedNodeMulticastPrefix = [13]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff} + +// SolicitedNodeAddr computes the solicited-node multicast address. This is +// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 +// address. +func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { + addrBytes := addr.As16() + return tcpip.AddrFrom16([16]byte(append(solicitedNodeMulticastPrefix[:], addrBytes[len(addrBytes)-3:]...))) +} + +// IsSolicitedNodeAddr determines whether the address is a solicited-node +// multicast address. +func IsSolicitedNodeAddr(addr tcpip.Address) bool { + addrBytes := addr.As16() + return solicitedNodeMulticastPrefix == [13]byte(addrBytes[:len(addrBytes)-3]) +} + +// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 +// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. +// +// buf MUST be at least 8 bytes. +func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { + buf[0] = linkAddr[0] ^ 2 + buf[1] = linkAddr[1] + buf[2] = linkAddr[2] + buf[3] = 0xFF + buf[4] = 0xFE + buf[5] = linkAddr[3] + buf[6] = linkAddr[4] + buf[7] = linkAddr[5] +} + +// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit +// Ethernet/MAC address, as per RFC 4291 section 2.5.1. +func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { + var buf [IIDSize]byte + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + return buf +} + +// LinkLocalAddr computes the default IPv6 link-local address from a link-layer +// (MAC) address. +func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { + // Convert a 48-bit MAC to a modified EUI-64 and then prepend the + // link-local header, FE80::. + // + // The conversion is very nearly: + // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff + // Note the capital A. The conversion aa->Aa involves a bit flip. + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) + return tcpip.AddrFrom16(lladdrb) +} + +// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6 +// link-local unicast address, as defined by RFC 4291 section 2.5.6. +func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool { + if addr.BitLen() != IPv6AddressSizeBits { + return false + } + addrBytes := addr.As16() + return addrBytes[0] == 0xfe && (addrBytes[1]&0xc0) == 0x80 +} + +// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback +// address, as defined by RFC 4291 section 2.5.3. +func IsV6LoopbackAddress(addr tcpip.Address) bool { + return addr == IPv6Loopback +} + +// IsV6LinkLocalMulticastAddress returns true iff the provided address is an +// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7. +func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { + return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope +} + +// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier +// (IID) to buf as outlined by RFC 7217 and returns the extended buffer. +// +// The opaque IID is generated from the cryptographic hash of the concatenation +// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret +// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and +// MUST be generated to a pseudo-random number. See RFC 4086 for randomness +// requirements for security. +// +// If buf has enough capacity for the IID (IIDSize bytes), a new underlying +// array for the buffer will not be allocated. +func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte { + // As per RFC 7217 section 5, the opaque identifier can be generated as a + // cryptographic hash of the concatenation of each of the function parameters. + // Note, we omit the optional Network_ID field. + h := sha256.New() + // h.Write never returns an error. + prefixID := prefix.ID() + h.Write([]byte(prefixID.AsSlice()[:IIDOffsetInIPv6Address])) + h.Write([]byte(nicName)) + h.Write([]byte{dadCounter}) + h.Write(secretKey) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + return append(buf, sum[:IIDSize]...) +} + +// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with +// an opaque IID. +func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address { + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + return tcpip.AddrFrom16([16]byte(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey))) +} + +// IPv6AddressScope is the scope of an IPv6 address. +type IPv6AddressScope int + +const ( + // LinkLocalScope indicates a link-local address. + LinkLocalScope IPv6AddressScope = iota + + // GlobalScope indicates a global address. + GlobalScope +) + +// ScopeForIPv6Address returns the scope for an IPv6 address. +func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) { + if addr.BitLen() != IPv6AddressSizeBits { + return GlobalScope, &tcpip.ErrBadAddress{} + } + + switch { + case IsV6LinkLocalMulticastAddress(addr): + return LinkLocalScope, nil + + case IsV6LinkLocalUnicastAddress(addr): + return LinkLocalScope, nil + + default: + return GlobalScope, nil + } +} + +// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an +// associated stable/permanent SLAAC address. +// +// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be +// used when generating a new temporary IID. +// +// Panics if tempIIDHistory is not at least IIDSize bytes. +func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix { + addrBytes := stableAddr.As16() + h := sha256.New() + h.Write(tempIIDHistory) + h.Write(addrBytes[IIDOffsetInIPv6Address:]) + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + // The rightmost 64 bits of sum are saved for the next iteration. + if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } + + // The leftmost 64 bits of sum is used as the IID. + if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize { + panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize)) + } + + return tcpip.AddressWithPrefix{ + Address: tcpip.AddrFrom16(addrBytes), + PrefixLen: IIDOffsetInIPv6Address * 8, + } +} + +// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by +// RFC 7346 section 2. +type IPv6MulticastScope uint8 + +// The various values for IPv6 multicast scopes, as per RFC 7346 section 2: +// +// +------+--------------------------+-------------------------+ +// | scop | NAME | REFERENCE | +// +------+--------------------------+-------------------------+ +// | 0 | Reserved | [RFC4291], RFC 7346 | +// | 1 | Interface-Local scope | [RFC4291], RFC 7346 | +// | 2 | Link-Local scope | [RFC4291], RFC 7346 | +// | 3 | Realm-Local scope | [RFC4291], RFC 7346 | +// | 4 | Admin-Local scope | [RFC4291], RFC 7346 | +// | 5 | Site-Local scope | [RFC4291], RFC 7346 | +// | 6 | Unassigned | | +// | 7 | Unassigned | | +// | 8 | Organization-Local scope | [RFC4291], RFC 7346 | +// | 9 | Unassigned | | +// | A | Unassigned | | +// | B | Unassigned | | +// | C | Unassigned | | +// | D | Unassigned | | +// | E | Global scope | [RFC4291], RFC 7346 | +// | F | Reserved | [RFC4291], RFC 7346 | +// +------+--------------------------+-------------------------+ +const ( + IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0) + IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1) + IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2) + IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3) + IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4) + IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5) + IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8) + IPv6GlobalMulticastScope = IPv6MulticastScope(0xE) + IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF) +) + +// V6MulticastScope returns the scope of a multicast address. +func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope { + addrBytes := addr.As16() + return IPv6MulticastScope(addrBytes[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask) +} diff --git a/internal/gtcpip/header/ipv6_extension_headers.go b/internal/gtcpip/header/ipv6_extension_headers.go new file mode 100644 index 0000000..3ab135d --- /dev/null +++ b/internal/gtcpip/header/ipv6_extension_headers.go @@ -0,0 +1,955 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math" + + "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing/common" +) + +// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier. +type IPv6ExtensionHeaderIdentifier uint8 + +const ( + // IPv6HopByHopOptionsExtHdrIdentifier is the header identifier of a Hop by + // Hop Options extension header, as per RFC 8200 section 4.3. + IPv6HopByHopOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 0 + + // IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension + // header, as per RFC 8200 section 4.4. + IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43 + + // IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment + // extension header, as per RFC 8200 section 4.5. + IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44 + + // IPv6DestinationOptionsExtHdrIdentifier is the header identifier of a + // Destination Options extension header, as per RFC 8200 section 4.6. + IPv6DestinationOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 60 + + // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end + // of an IPv6 payload, as per RFC 8200 section 4.7. + IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59 + + // IPv6UnknownExtHdrIdentifier is reserved by IANA. + // https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header + // "254 Use for experimentation and testing [RFC3692][RFC4727]" + IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254 +) + +const ( + // ipv6UnknownExtHdrOptionActionMask is the mask of the action to take when + // a node encounters an unrecognized option. + ipv6UnknownExtHdrOptionActionMask = 192 + + // ipv6UnknownExtHdrOptionActionShift is the least significant bits to discard + // from the action value for an unrecognized option identifier. + ipv6UnknownExtHdrOptionActionShift = 6 + + // ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field + // within an IPv6RoutingExtHdr. + ipv6RoutingExtHdrSegmentsLeftIdx = 1 + + // IPv6FragmentExtHdrLength is the length of an IPv6 extension header, in + // bytes. + IPv6FragmentExtHdrLength = 8 + + // ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the + // Fragment Offset field within an IPv6FragmentExtHdr. + ipv6FragmentExtHdrFragmentOffsetOffset = 0 + + // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment + // Offset field within an IPv6FragmentExtHdr. + ipv6FragmentExtHdrFragmentOffsetShift = 3 + + // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an + // IPv6FragmentExtHdr. + ipv6FragmentExtHdrFlagsIdx = 1 + + // ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the + // flags field of an IPv6FragmentExtHdr. + ipv6FragmentExtHdrMFlagMask = 1 + + // ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification + // field within an IPv6FragmentExtHdr. + ipv6FragmentExtHdrIdentificationOffset = 2 + + // ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length + // field. That is, given a Length field of 2, the extension header expects + // 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for + // details about the first 8 bytes' exclusion from the Length field). + ipv6ExtHdrLenBytesPerUnit = 8 + + // ipv6ExtHdrLenBytesExcluded is the number of bytes excluded from an + // extension header's Length field following the Length field. + // + // The Length field excludes the first 8 bytes, but the Next Header and Length + // field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes + // after the Length field. + // + // This ensures that every extension header is at least 8 bytes. + ipv6ExtHdrLenBytesExcluded = 6 + + // IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment + // extension header's Fragment Offset field. That is, given a Fragment Offset + // of 2, the extension header is indicating that the fragment's payload + // starts at the 16th byte in the reassembled packet. + IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8 +) + +// padIPv6OptionsLength returns the total length for IPv6 options of length l +// considering the 8-octet alignment as stated in RFC 8200 Section 4.2. +func padIPv6OptionsLength(length int) int { + return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1) +} + +// padIPv6Option fills b with the appropriate padding options depending on its +// length. +func padIPv6Option(b []byte) { + switch len(b) { + case 0: // No padding needed. + case 1: // Pad with Pad1. + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier) + default: // Pad with PadN. + s := b[ipv6ExtHdrOptionPayloadOffset:] + common.ClearArray(s) + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier) + b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s)) + } +} + +// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to +// serialize an option at headerOffset with alignment requirements +// [align]n + alignOffset. +func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int { + padLen := headerOffset - alignOffset + return ((padLen + align - 1) & ^(align - 1)) - padLen +} + +// IPv6PayloadHeader is implemented by the various headers that can be found +// in an IPv6 payload. +// +// These headers include IPv6 extension headers or upper layer data. +type IPv6PayloadHeader interface { + isIPv6PayloadHeader() + + // Release frees all resources held by the header. + Release() +} + +// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator +// encounters a Next Header field it does not recognize as an IPv6 extension +// header. The caller is responsible for releasing the underlying buffer after +// it's no longer needed. +type IPv6RawPayloadHeader struct { + Identifier IPv6ExtensionHeaderIdentifier + Buf buffer.Buffer +} + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {} + +// Release implements IPv6PayloadHeader.Release. +func (i IPv6RawPayloadHeader) Release() { + i.Buf.Release() +} + +// ipv6OptionsExtHdr is an IPv6 extension header that holds options. +type ipv6OptionsExtHdr struct { + buf *buffer.View +} + +// Release implements IPv6PayloadHeader.Release. +func (i ipv6OptionsExtHdr) Release() { + if i.buf != nil { + i.buf.Release() + } +} + +// Iter returns an iterator over the IPv6 extension header options held in b. +func (i ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator { + it := IPv6OptionsExtHdrOptionsIterator{} + it.reader = i.buf + return it +} + +// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header +// options. +// +// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last +// used, no changes to the underlying buffer may happen. Doing so may cause +// undefined and unexpected behaviour. It is fine to obtain an +// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then +// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator +// obtained before modification is no longer used. +type IPv6OptionsExtHdrOptionsIterator struct { + reader *buffer.View + + // optionOffset is the number of bytes from the first byte of the + // options field to the beginning of the current option. + optionOffset uint32 + + // nextOptionOffset is the offset of the next option. + nextOptionOffset uint32 +} + +// OptionOffset returns the number of bytes parsed while processing the +// option field of the current Extension Header. +func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 { + return i.optionOffset +} + +// IPv6OptionUnknownAction is the action that must be taken if the processing +// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2. +type IPv6OptionUnknownAction int + +const ( + // IPv6OptionUnknownActionSkip indicates that the unrecognized option must + // be skipped and the node should continue processing the header. + IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = 0 + + // IPv6OptionUnknownActionDiscard indicates that the packet must be silently + // discarded. + IPv6OptionUnknownActionDiscard IPv6OptionUnknownAction = 1 + + // IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be + // discarded and the node must send an ICMP Parameter Problem, Code 2, message + // to the packet's source, regardless of whether or not the packet's + // Destination was a multicast address. + IPv6OptionUnknownActionDiscardSendICMP IPv6OptionUnknownAction = 2 + + // IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the + // packet must be discarded and the node must send an ICMP Parameter Problem, + // Code 2, message to the packet's source only if the packet's Destination was + // not a multicast address. + IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest IPv6OptionUnknownAction = 3 +) + +// IPv6ExtHdrOption is implemented by the various IPv6 extension header options. +type IPv6ExtHdrOption interface { + // UnknownAction returns the action to take in response to an unrecognized + // option. + UnknownAction() IPv6OptionUnknownAction + + // isIPv6ExtHdrOption is used to "lock" this interface so it is not + // implemented by other packages. + isIPv6ExtHdrOption() +} + +// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier. +type IPv6ExtHdrOptionIdentifier uint8 + +const ( + // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that + // provides 1 byte padding, as outlined in RFC 8200 section 4.2. + ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0 + + // ipv6PadNExtHdrOptionIdentifier is the identifier for a padding option that + // provides variable length byte padding, as outlined in RFC 8200 section 4.2. + ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1 + + // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router + // Alert Hop by Hop option as defined in RFC 2711 section 2.1. + ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5 + + // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header + // option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionTypeOffset = 0 + + // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionLengthOffset = 1 + + // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionPayloadOffset = 2 +) + +// ipv6UnknownActionFromIdentifier maps an extension header option's +// identifier's high bits to the action to take when the identifier is unknown. +func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction { + return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) +} + +// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option +// is malformed. +var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option") + +// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension +// header option that is unknown by the parsing utilities. +type IPv6UnknownExtHdrOption struct { + Identifier IPv6ExtHdrOptionIdentifier + Data *buffer.View +} + +// UnknownAction implements IPv6OptionUnknownAction.UnknownAction. +func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction { + return ipv6UnknownActionFromIdentifier(o.Identifier) +} + +// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption. +func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {} + +// Next returns the next option in the options data. +// +// If the next item is not a known extension header option, +// IPv6UnknownExtHdrOption will be returned with the option identifier and data. +// +// The return is of the format (option, done, error). done will be true when +// Next is unable to return anything because the iterator has reached the end of +// the options data, or an error occurred. +func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) { + for { + i.optionOffset = i.nextOptionOffset + temp, err := i.reader.ReadByte() + if err != nil { + // If we can't read the first byte of a new option, then we know the + // options buffer has been exhausted and we are done iterating. + return nil, true, nil + } + id := IPv6ExtHdrOptionIdentifier(temp) + + // If the option identifier indicates the option is a Pad1 option, then we + // know the option does not have Length and Data fields. End processing of + // the Pad1 option and continue processing the buffer as a new option. + if id == ipv6Pad1ExtHdrOptionIdentifier { + i.nextOptionOffset = i.optionOffset + 1 + continue + } + + length, err := i.reader.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err)) + } + + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once + // we start parsing an option; we expect the reader to contain enough + // bytes for the whole option. + return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF) + } + + // Do we have enough bytes in the reader for the next option? + if n := i.reader.Size(); n < int(length) { + // Consume the remaining buffer. + i.reader.TrimFront(i.reader.Size()) + + // We return the same error as if we failed to read a non-padding option + // so consumers of this iterator don't need to differentiate between + // padding and non-padding options. + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) + } + + i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */ + + switch id { + case ipv6PadNExtHdrOptionIdentifier: + // Special-case the variable length padding option to avoid a copy. + i.reader.TrimFront(int(length)) + continue + case ipv6RouterAlertHopByHopOptionIdentifier: + var routerAlertValue [ipv6RouterAlertPayloadLength]byte + if n, err := io.ReadFull(i.reader, routerAlertValue[:]); err != nil { + switch err { + case io.EOF, io.ErrUnexpectedEOF: + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + default: + return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err) + } + } else if n != int(length) { + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + } + return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil + default: + bytes := buffer.NewView(int(length)) + if n, err := io.CopyN(bytes, i.reader, int64(length)); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) + } + return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil + } + } +} + +// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options +// extension header. +type IPv6HopByHopOptionsExtHdr struct { + ipv6OptionsExtHdr +} + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {} + +// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options +// extension header. +type IPv6DestinationOptionsExtHdr struct { + ipv6OptionsExtHdr +} + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {} + +// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific +// data as outlined in RFC 8200 section 4.4. +type IPv6RoutingExtHdr struct { + Buf *buffer.View +} + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {} + +// Release implements IPv6PayloadHeader.Release. +func (b IPv6RoutingExtHdr) Release() { + b.Buf.Release() +} + +// SegmentsLeft returns the Segments Left field. +func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 { + return b.Buf.AsSlice()[ipv6RoutingExtHdrSegmentsLeftIdx] +} + +// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific +// data as outlined in RFC 8200 section 4.5. +// +// Note, the buffer does not include the Next Header and Reserved fields. +type IPv6FragmentExtHdr [6]byte + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {} + +// Release implements IPv6PayloadHeader.Release. +func (IPv6FragmentExtHdr) Release() {} + +// FragmentOffset returns the Fragment Offset field. +// +// This value indicates where the buffer following the Fragment extension header +// starts in the target (reassembled) packet. +func (b IPv6FragmentExtHdr) FragmentOffset() uint16 { + return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift +} + +// More returns the More (M) flag. +// +// This indicates whether any fragments are expected to succeed b. +func (b IPv6FragmentExtHdr) More() bool { + return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0 +} + +// ID returns the Identification field. +// +// This value is used to uniquely identify the packet, between a +// source and destination. +func (b IPv6FragmentExtHdr) ID() uint32 { + return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:]) +} + +// IsAtomic returns whether the fragment header indicates an atomic fragment. An +// atomic fragment is a fragment that contains all the data required to +// reassemble a full packet. +func (b IPv6FragmentExtHdr) IsAtomic() bool { + return !b.More() && b.FragmentOffset() == 0 +} + +// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload. +// +// The IPv6 payload may contain IPv6 extension headers before any upper layer +// data. +// +// Note, between when an IPv6PayloadIterator is obtained and last used, no +// changes to the payload may happen. Doing so may cause undefined and +// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate +// over the first few headers then modify the backing payload so long as the +// IPv6PayloadIterator obtained before modification is no longer used. +type IPv6PayloadIterator struct { + // The identifier of the next header to parse. + nextHdrIdentifier IPv6ExtensionHeaderIdentifier + + payload buffer.Buffer + + // Indicates to the iterator that it should return the remaining payload as a + // raw payload on the next call to Next. + forceRaw bool + + // headerOffset is the offset of the beginning of the current extension + // header starting from the beginning of the fixed header. + headerOffset uint32 + + // parseOffset is the byte offset into the current extension header of the + // field we are currently examining. It can be added to the header offset + // if the absolute offset within the packet is required. + parseOffset uint32 + + // nextOffset is the offset of the next header. + nextOffset uint32 +} + +// HeaderOffset returns the offset to the start of the extension +// header most recently processed. +func (i IPv6PayloadIterator) HeaderOffset() uint32 { + return i.headerOffset +} + +// ParseOffset returns the number of bytes successfully parsed. +func (i IPv6PayloadIterator) ParseOffset() uint32 { + return i.headerOffset + i.parseOffset +} + +// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing +// extension headers, or a raw payload if the payload cannot be parsed. The +// iterator takes ownership of the payload. +func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.Buffer) IPv6PayloadIterator { + return IPv6PayloadIterator{ + nextHdrIdentifier: nextHdrIdentifier, + payload: payload, + nextOffset: IPv6FixedHeaderSize, + } +} + +// Release frees the resources owned by the iterator. +func (i *IPv6PayloadIterator) Release() { + i.payload.Release() +} + +// AsRawHeader returns the remaining payload of i as a raw header and +// optionally consumes the iterator. +// +// If consume is true, calls to Next after calling AsRawHeader on i will +// indicate that the iterator is done. The returned header takes ownership of +// its payload. +func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { + identifier := i.nextHdrIdentifier + + var buf buffer.Buffer + if consume { + // Since we consume the iterator, we return the payload as is. + buf = i.payload + + // Mark i as done, but keep track of where we were for error reporting. + *i = IPv6PayloadIterator{ + nextHdrIdentifier: IPv6NoNextHeaderIdentifier, + headerOffset: i.headerOffset, + nextOffset: i.nextOffset, + } + } else { + buf = i.payload.Clone() + } + + return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf} +} + +// Next returns the next item in the payload. +// +// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader +// will be returned with the remaining bytes and next header identifier. +// +// The return is of the format (header, done, error). done will be true when +// Next is unable to return anything because the iterator has reached the end of +// the payload, or an error occurred. +func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { + i.headerOffset = i.nextOffset + i.parseOffset = 0 + // We could be forced to return i as a raw header when the previous header was + // a fragment extension header as the data following the fragment extension + // header may not be complete. + if i.forceRaw { + return i.AsRawHeader(true /* consume */), false, nil + } + + // Is the header we are parsing a known extension header? + switch i.nextHdrIdentifier { + case IPv6HopByHopOptionsExtHdrIdentifier: + nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil) + if err != nil { + return nil, true, err + } + + i.nextHdrIdentifier = nextHdrIdentifier + return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil + case IPv6RoutingExtHdrIdentifier: + nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil) + if err != nil { + return nil, true, err + } + + i.nextHdrIdentifier = nextHdrIdentifier + return IPv6RoutingExtHdr{view}, false, nil + case IPv6FragmentExtHdrIdentifier: + var data [6]byte + // We ignore the returned bytes because we know the fragment extension + // header specific data will fit in data. + nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) + if err != nil { + return nil, true, err + } + + fragmentExtHdr := IPv6FragmentExtHdr(data) + + // If the packet is not the first fragment, do not attempt to parse anything + // after the fragment extension header as the payload following the fragment + // extension header should not contain any headers; the first fragment must + // hold all the headers up to and including any upper layer headers, as per + // RFC 8200 section 4.5. + if fragmentExtHdr.FragmentOffset() != 0 { + i.forceRaw = true + } + + i.nextHdrIdentifier = nextHdrIdentifier + return fragmentExtHdr, false, nil + case IPv6DestinationOptionsExtHdrIdentifier: + nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil) + if err != nil { + return nil, true, err + } + + i.nextHdrIdentifier = nextHdrIdentifier + return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil + case IPv6NoNextHeaderIdentifier: + // This indicates the end of the IPv6 payload. + return nil, true, nil + + default: + // The header we are parsing is not a known extension header. Return the + // raw payload. + return i.AsRawHeader(true /* consume */), false, nil + } +} + +// NextHeaderIdentifier returns the identifier of the header next returned by +// it.Next(). +func (i *IPv6PayloadIterator) NextHeaderIdentifier() IPv6ExtensionHeaderIdentifier { + return i.nextHdrIdentifier +} + +// nextHeaderData returns the extension header's Next Header field and raw data. +// +// fragmentHdr indicates that the extension header being parsed is the Fragment +// extension header so the Length field should be ignored as it is Reserved +// for the Fragment extension header. +// +// If bytes is not nil, extension header specific data will be read into bytes +// if it has enough capacity. If bytes is provided but does not have enough +// capacity for the data, nextHeaderData will panic. +func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, *buffer.View, error) { + // We ignore the number of bytes read because we know we will only ever read + // at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read + // would return io.EOF to indicate that io.Reader has reached the end of the + // payload. + rdr := i.payload.AsBufferReader() + nextHdrIdentifier, err := rdr.ReadByte() + if err != nil { + return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } + i.parseOffset++ + + var length uint8 + length, err = rdr.ReadByte() + if err != nil { + if fragmentHdr { + return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } + + return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } + if fragmentHdr { + length = 0 + } + + // Make parseOffset point to the first byte of the Extension Header + // specific data. + i.parseOffset++ + + // length is in 8 byte chunks but doesn't include the first one. + // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement + // in section 4.8 for new extension headers at the top of page 24. + // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet + // units, not including the first 8 octets. + i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit) + + bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded + if fragmentHdr { + if n := len(bytes); n < bytesLen { + panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier)) + } + if n, err := io.ReadFull(&rdr, bytes); err != nil { + return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err) + } + return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), nil, nil + } + v := buffer.NewView(bytesLen) + if n, err := io.CopyN(v, &rdr, int64(bytesLen)); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + v.Release() + return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err) + } + return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), v, nil +} + +// IPv6SerializableExtHdr provides serialization for IPv6 extension +// headers. +type IPv6SerializableExtHdr interface { + // identifier returns the assigned IPv6 header identifier for this extension + // header. + identifier() IPv6ExtensionHeaderIdentifier + + // length returns the total serialized length in bytes of this extension + // header, including the common next header and length fields. + length() int + + // serializeInto serializes the receiver into the provided byte + // buffer and with the provided nextHeader value. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto returns the number of bytes that was used to serialize the + // receiver. Implementers must only use the number of bytes required to + // serialize the receiver. Callers MAY provide a larger buffer than required + // to serialize into. + serializeInto(nextHeader uint8, b []byte) int +} + +var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil) + +// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop +// options extension header. +type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption + +const ( + // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field + // in a hop by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrNextHeaderOffset = 0 + + // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop + // by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrLengthOffset = 1 + + // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by + // hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrOptionsOffset = 2 + + // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet + // words in a hop by hop extension header's length field, as stated in RFC + // 8200 section 4.3: + // Length of the Hop-by-Hop Options header in 8-octet units, + // not including the first 8 octets. + ipv6HopByHopExtHdrUnaccountedLenWords = 1 +) + +// identifier implements IPv6SerializableExtHdr. +func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6HopByHopOptionsExtHdrIdentifier +} + +// length implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) length() int { + var total int + for _, opt := range h { + align, alignOffset := opt.alignment() + total += ipv6OptionsAlignmentPadding(total, align, alignOffset) + total += ipv6ExtHdrOptionPayloadOffset + int(opt.length()) + } + // Account for next header and total length fields and add padding. + return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total) +} + +// serializeInto implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int { + optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:] + totalLength := ipv6HopByHopExtHdrOptionsOffset + for _, opt := range h { + // Calculate alignment requirements and pad buffer if necessary. + align, alignOffset := opt.alignment() + padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset) + if padLen != 0 { + padIPv6Option(optBuffer[:padLen]) + totalLength += padLen + optBuffer = optBuffer[padLen:] + } + + l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:]) + optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier()) + optBuffer[ipv6ExtHdrOptionLengthOffset] = l + l += ipv6ExtHdrOptionPayloadOffset + totalLength += int(l) + optBuffer = optBuffer[l:] + } + padded := padIPv6OptionsLength(totalLength) + if padded != totalLength { + padIPv6Option(optBuffer[:padded-totalLength]) + totalLength = padded + } + wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords + if wordsLen > math.MaxUint8 { + panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen)) + } + b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader + b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen) + return totalLength +} + +// IPv6SerializableHopByHopOption provides serialization for hop by hop options. +type IPv6SerializableHopByHopOption interface { + // identifier returns the option identifier of this Hop by Hop option. + identifier() IPv6ExtHdrOptionIdentifier + + // length returns the *payload* size of the option (not considering the type + // and length fields). + length() uint8 + + // alignment returns the alignment requirements from this option. + // + // Alignment requirements take the form [align]n + offset as specified in + // RFC 8200 section 4.2. The alignment requirement is on the offset between + // the option type byte and the start of the hop by hop header. + // + // align must be a power of 2. + alignment() (align int, offset int) + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) uint8 +} + +var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil) + +// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in +// RFC 2711 section 2.1. +type IPv6RouterAlertOption struct { + Value IPv6RouterAlertValue +} + +// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option. +type IPv6RouterAlertValue uint16 + +const ( + // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener + // Discovery message as defined in RFC 2711 section 2.1. + IPv6RouterAlertMLD IPv6RouterAlertValue = 0 + // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as + // defined in RFC 2711 section 2.1. + IPv6RouterAlertRSVP IPv6RouterAlertValue = 1 + // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active + // Networks message as defined in RFC 2711 section 2.1. + IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2 + + // ipv6RouterAlertPayloadLength is the length of the Router Alert payload + // as defined in RFC 2711. + ipv6RouterAlertPayloadLength = 2 + + // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the + // Router Alert option defined as 2n+0 in RFC 2711. + ipv6RouterAlertAlignmentRequirement = 2 + + // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset + // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section + // 2.1. + ipv6RouterAlertAlignmentOffsetRequirement = 0 +) + +// UnknownAction implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction { + return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier) +} + +// isIPv6ExtHdrOption implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {} + +// identifier implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier { + return ipv6RouterAlertHopByHopOptionIdentifier +} + +// length implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) length() uint8 { + return ipv6RouterAlertPayloadLength +} + +// alignment implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) alignment() (int, int) { + // From RFC 2711 section 2.1: + // Alignment requirement: 2n+0. + return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 { + binary.BigEndian.PutUint16(b, uint16(o.Value)) + return ipv6RouterAlertPayloadLength +} + +// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers. +type IPv6ExtHdrSerializer []IPv6SerializableExtHdr + +// Serialize serializes the provided list of IPv6 extension headers into b. +// +// Note, b must be of sufficient size to hold all the headers in s. See +// IPv6ExtHdrSerializer.Length for details on the getting the total size of a +// serialized IPv6ExtHdrSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +// +// Serialize takes the transportProtocol value to be used as the last extension +// header's Next Header value and returns the header identifier of the first +// serialized extension header and the total serialized length. +func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) { + nextHeader := uint8(transportProtocol) + if len(s) == 0 { + return nextHeader, 0 + } + var totalLength int + for i, h := range s[:len(s)-1] { + length := h.serializeInto(uint8(s[i+1].identifier()), b) + b = b[length:] + totalLength += length + } + totalLength += s[len(s)-1].serializeInto(nextHeader, b) + return uint8(s[0].identifier()), totalLength +} + +// Length returns the total number of bytes required to serialize the extension +// headers. +func (s IPv6ExtHdrSerializer) Length() int { + var totalLength int + for _, h := range s { + totalLength += h.length() + } + return totalLength +} diff --git a/internal/gtcpip/header/ipv6_fragment.go b/internal/gtcpip/header/ipv6_fragment.go new file mode 100644 index 0000000..49aaca7 --- /dev/null +++ b/internal/gtcpip/header/ipv6_fragment.go @@ -0,0 +1,158 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "github.com/sagernet/sing-tun/internal/gtcpip" +) + +const ( + nextHdrFrag = 0 + fragOff = 2 + more = 3 + idV6 = 4 +) + +var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil) + +// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment +// extension header as defined in RFC 8200 section 4.5. +type IPv6SerializableFragmentExtHdr struct { + // FragmentOffset is the "fragment offset" field of an IPv6 fragment. + FragmentOffset uint16 + + // M is the "more" field of an IPv6 fragment. + M bool + + // Identification is the "identification" field of an IPv6 fragment. + Identification uint32 +} + +// identifier implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6FragmentHeader +} + +// length implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) length() int { + return IPv6FragmentHeaderSize +} + +// serializeInto implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int { + // Prevent too many bounds checks. + _ = b[IPv6FragmentHeaderSize:] + binary.BigEndian.PutUint32(b[idV6:], h.Identification) + binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<= IPv6FragmentHeaderSize +} + +// NextHeader returns the value of the "next header" field of the ipv6 fragment. +func (b IPv6Fragment) NextHeader() uint8 { + return b[nextHdrFrag] +} + +// FragmentOffset returns the "fragment offset" field of the ipv6 fragment. +func (b IPv6Fragment) FragmentOffset() uint16 { + return binary.BigEndian.Uint16(b[fragOff:]) >> 3 +} + +// More returns the "more" field of the ipv6 fragment. +func (b IPv6Fragment) More() bool { + return b[more]&1 > 0 +} + +// Payload implements Network.Payload. +func (b IPv6Fragment) Payload() []byte { + return b[IPv6FragmentHeaderSize:] +} + +// ID returns the value of the identifier field of the ipv6 fragment. +func (b IPv6Fragment) ID() uint32 { + return binary.BigEndian.Uint32(b[idV6:]) +} + +// TransportProtocol implements Network.TransportProtocol. +func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber { + return tcpip.TransportProtocolNumber(b.NextHeader()) +} + +// The functions below have been added only to satisfy the Network interface. + +// Checksum is not supported by IPv6Fragment. +func (b IPv6Fragment) Checksum() uint16 { + panic("not supported") +} + +// SourceAddress is not supported by IPv6Fragment. +func (b IPv6Fragment) SourceAddress() tcpip.Address { + panic("not supported") +} + +// DestinationAddress is not supported by IPv6Fragment. +func (b IPv6Fragment) DestinationAddress() tcpip.Address { + panic("not supported") +} + +// SetSourceAddress is not supported by IPv6Fragment. +func (b IPv6Fragment) SetSourceAddress(tcpip.Address) { + panic("not supported") +} + +// SetDestinationAddress is not supported by IPv6Fragment. +func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) { + panic("not supported") +} + +// SetChecksum is not supported by IPv6Fragment. +func (b IPv6Fragment) SetChecksum(uint16) { + panic("not supported") +} + +// TOS is not supported by IPv6Fragment. +func (b IPv6Fragment) TOS() (uint8, uint32) { + panic("not supported") +} + +// SetTOS is not supported by IPv6Fragment. +func (b IPv6Fragment) SetTOS(t uint8, l uint32) { + panic("not supported") +} diff --git a/internal/gtcpip/header/ndp_neighbor_advert.go b/internal/gtcpip/header/ndp_neighbor_advert.go new file mode 100644 index 0000000..7a934cc --- /dev/null +++ b/internal/gtcpip/header/ndp_neighbor_advert.go @@ -0,0 +1,110 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import "github.com/sagernet/sing-tun/internal/gtcpip" + +// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will +// only contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.4 for more details. +type NDPNeighborAdvert []byte + +const ( + // NDPNAMinimumSize is the minimum size of a valid NDP Neighbor + // Advertisement message (body of an ICMPv6 packet). + NDPNAMinimumSize = 20 + + // ndpNATargetAddressOffset is the start of the Target Address + // field within an NDPNeighborAdvert. + ndpNATargetAddressOffset = 4 + + // ndpNAOptionsOffset is the start of the NDP options in an + // NDPNeighborAdvert. + ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize + + // ndpNAFlagsOffset is the offset of the flags within an + // NDPNeighborAdvert + ndpNAFlagsOffset = 0 + + // ndpNARouterFlagMask is the mask of the Router Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNARouterFlagMask = (1 << 7) + + // ndpNASolicitedFlagMask is the mask of the Solicited Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNASolicitedFlagMask = (1 << 6) + + // ndpNAOverrideFlagMask is the mask of the Override Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNAOverrideFlagMask = (1 << 5) +) + +// TargetAddress returns the value within the Target Address field. +func (b NDPNeighborAdvert) TargetAddress() tcpip.Address { + return tcpip.AddrFrom16Slice(b[ndpNATargetAddressOffset:][:IPv6AddressSize]) +} + +// SetTargetAddress sets the value within the Target Address field. +func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) { + copy(b[ndpNATargetAddressOffset:][:IPv6AddressSize], addr.AsSlice()) +} + +// RouterFlag returns the value of the Router Flag field. +func (b NDPNeighborAdvert) RouterFlag() bool { + return b[ndpNAFlagsOffset]&ndpNARouterFlagMask != 0 +} + +// SetRouterFlag sets the value in the Router Flag field. +func (b NDPNeighborAdvert) SetRouterFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNARouterFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNARouterFlagMask + } +} + +// SolicitedFlag returns the value of the Solicited Flag field. +func (b NDPNeighborAdvert) SolicitedFlag() bool { + return b[ndpNAFlagsOffset]&ndpNASolicitedFlagMask != 0 +} + +// SetSolicitedFlag sets the value in the Solicited Flag field. +func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNASolicitedFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNASolicitedFlagMask + } +} + +// OverrideFlag returns the value of the Override Flag field. +func (b NDPNeighborAdvert) OverrideFlag() bool { + return b[ndpNAFlagsOffset]&ndpNAOverrideFlagMask != 0 +} + +// SetOverrideFlag sets the value in the Override Flag field. +func (b NDPNeighborAdvert) SetOverrideFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNAOverrideFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNAOverrideFlagMask + } +} + +// Options returns an NDPOptions of the options body. +func (b NDPNeighborAdvert) Options() NDPOptions { + return NDPOptions(b[ndpNAOptionsOffset:]) +} diff --git a/internal/gtcpip/header/ndp_neighbor_solicit.go b/internal/gtcpip/header/ndp_neighbor_solicit.go new file mode 100644 index 0000000..61d61a8 --- /dev/null +++ b/internal/gtcpip/header/ndp_neighbor_solicit.go @@ -0,0 +1,52 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import "github.com/sagernet/sing-tun/internal/gtcpip" + +// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only +// contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.3 for more details. +type NDPNeighborSolicit []byte + +const ( + // NDPNSMinimumSize is the minimum size of a valid NDP Neighbor + // Solicitation message (body of an ICMPv6 packet). + NDPNSMinimumSize = 20 + + // ndpNSTargetAddessOffset is the start of the Target Address + // field within an NDPNeighborSolicit. + ndpNSTargetAddessOffset = 4 + + // ndpNSOptionsOffset is the start of the NDP options in an + // NDPNeighborSolicit. + ndpNSOptionsOffset = ndpNSTargetAddessOffset + IPv6AddressSize +) + +// TargetAddress returns the value within the Target Address field. +func (b NDPNeighborSolicit) TargetAddress() tcpip.Address { + return tcpip.AddrFrom16Slice(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]) +} + +// SetTargetAddress sets the value within the Target Address field. +func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) { + copy(b[ndpNSTargetAddessOffset:][:IPv6AddressSize], addr.AsSlice()) +} + +// Options returns an NDPOptions of the options body. +func (b NDPNeighborSolicit) Options() NDPOptions { + return NDPOptions(b[ndpNSOptionsOffset:]) +} diff --git a/internal/gtcpip/header/ndp_options.go b/internal/gtcpip/header/ndp_options.go new file mode 100644 index 0000000..ba29339 --- /dev/null +++ b/internal/gtcpip/header/ndp_options.go @@ -0,0 +1,1073 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "time" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing/common" +) + +// ndpOptionIdentifier is an NDP option type identifier. +type ndpOptionIdentifier uint8 + +const ( + // ndpSourceLinkLayerAddressOptionType is the type of the Source Link Layer + // Address option, as per RFC 4861 section 4.6.1. + ndpSourceLinkLayerAddressOptionType ndpOptionIdentifier = 1 + + // ndpTargetLinkLayerAddressOptionType is the type of the Target Link Layer + // Address option, as per RFC 4861 section 4.6.1. + ndpTargetLinkLayerAddressOptionType ndpOptionIdentifier = 2 + + // ndpPrefixInformationType is the type of the Prefix Information + // option, as per RFC 4861 section 4.6.2. + ndpPrefixInformationType ndpOptionIdentifier = 3 + + // ndpNonceOptionType is the type of the Nonce option, as per + // RFC 3971 section 5.3.2. + ndpNonceOptionType ndpOptionIdentifier = 14 + + // ndpRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + ndpRecursiveDNSServerOptionType ndpOptionIdentifier = 25 + + // ndpDNSSearchListOptionType is the type of the DNS Search List option, + // as per RFC 8106 section 5.2. + ndpDNSSearchListOptionType ndpOptionIdentifier = 31 +) + +const ( + // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer + // Address option for an Ethernet address. + NDPLinkLayerAddressSize = 8 + + // ndpPrefixInformationLength is the expected length, in bytes, of the + // body of an NDP Prefix Information option, as per RFC 4861 section + // 4.6.2 which specifies that the Length field is 4. Given this, the + // expected length, in bytes, is 30 because 4 * lengthByteUnits (8) - 2 + // (Type & Length) = 30. + ndpPrefixInformationLength = 30 + + // ndpPrefixInformationPrefixLengthOffset is the offset of the Prefix + // Length field within an NDPPrefixInformation. + ndpPrefixInformationPrefixLengthOffset = 0 + + // ndpPrefixInformationFlagsOffset is the offset of the flags byte + // within an NDPPrefixInformation. + ndpPrefixInformationFlagsOffset = 1 + + // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationOnLinkFlagMask = 1 << 7 + + // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the + // Autonomous Address-Configuration flag field in the flags byte within + // an NDPPrefixInformation. + ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6 + + // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationReserved1FlagsMask = 63 + + // ndpPrefixInformationValidLifetimeOffset is the start of the 4-byte + // Valid Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationValidLifetimeOffset = 2 + + // ndpPrefixInformationPreferredLifetimeOffset is the start of the + // 4-byte Preferred Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationPreferredLifetimeOffset = 6 + + // ndpPrefixInformationReserved2Offset is the start of the 4-byte + // Reserved2 field within an NDPPrefixInformation. + ndpPrefixInformationReserved2Offset = 10 + + // ndpPrefixInformationReserved2Length is the length of the Reserved2 + // field. + // + // It is 4 bytes. + ndpPrefixInformationReserved2Length = 4 + + // ndpPrefixInformationPrefixOffset is the start of the Prefix field + // within an NDPPrefixInformation. + ndpPrefixInformationPrefixOffset = 14 + + // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerLifetimeOffset = 2 + + // ndpRecursiveDNSServerAddressesOffset is the start of the addresses + // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerAddressesOffset = 6 + + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS Server + // option's body size when it contains at least one IPv6 address, as per + // RFC 8106 section 5.3.1. + minNDPRecursiveDNSServerBodySize = 22 + + // ndpDNSSearchListLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPDNSSearchList. + ndpDNSSearchListLifetimeOffset = 2 + + // ndpDNSSearchListDomainNamesOffset is the start of the DNS search list + // domain names within an NDPDNSSearchList. + ndpDNSSearchListDomainNamesOffset = 6 + + // minNDPDNSSearchListBodySize is the minimum NDP DNS Search List option's + // body size when it contains at least one domain name, as per RFC 8106 + // section 5.3.1. + minNDPDNSSearchListBodySize = 14 + + // maxDomainNameLabelLength is the maximum length of a domain name + // label, as per RFC 1035 section 3.1. + maxDomainNameLabelLength = 63 + + // maxDomainNameLength is the maximum length of a domain name, including + // label AND label length octet, as per RFC 1035 section 3.1. + maxDomainNameLength = 255 + + // lengthByteUnits is the multiplier factor for the Length field of an + // NDP option. That is, the length field for NDP options is in units of + // 8 octets, as per RFC 4861 section 4.6. + lengthByteUnits = 8 + + // NDPInfiniteLifetime is a value that represents infinity for the + // 4-byte lifetime fields found in various NDP options. Its value is + // (2^32 - 1)s = 4294967295s. + NDPInfiniteLifetime = time.Second * math.MaxUint32 +) + +// NDPOptionIterator is an iterator of NDPOption. +// +// Note, between when an NDPOptionIterator is obtained and last used, no changes +// to the NDPOptions may happen. Doing so may cause undefined and unexpected +// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first +// few NDPOption then modify the backing NDPOptions so long as the +// NDPOptionIterator obtained before modification is no longer used. +type NDPOptionIterator struct { + opts *bytes.Buffer +} + +// Potential errors when iterating over an NDPOptions. +var ( + ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPOptMalformedHeader = errors.New("NDP option has a malformed header") +) + +// Next returns the next element in the backing NDPOptions, or true if we are +// done, or false if an error occurred. +// +// The return can be read as option, done, error. Note, option should only be +// used if done is false and error is nil. +func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { + for { + // Do we still have elements to look at? + if i.opts.Len() == 0 { + return nil, true, nil + } + + // Get the Type field. + temp, err := i.opts.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading the option's Type field: %s", err)) + } + + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once + // we start parsing an option; we expect the buffer to contain enough + // bytes for the whole option. + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF) + } + kind := ndpOptionIdentifier(temp) + + // Get the Length field. + length, err := i.opts.ReadByte() + if err != nil { + if err != io.EOF { + panic(fmt.Sprintf("unexpected error when reading the option's Length field for %s: %s", kind, err)) + } + + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Length field for %s: %w", kind, io.ErrUnexpectedEOF) + } + + // This would indicate an erroneous NDP option as the Length field should + // never be 0. + if length == 0 { + return nil, true, fmt.Errorf("zero valued Length field for %s: %w", kind, ErrNDPOptMalformedHeader) + } + + // Get the body. + numBytes := int(length) * lengthByteUnits + numBodyBytes := numBytes - 2 + body := i.opts.Next(numBodyBytes) + if len(body) < numBodyBytes { + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Body for %s: %w", kind, io.ErrUnexpectedEOF) + } + + switch kind { + case ndpSourceLinkLayerAddressOptionType: + return NDPSourceLinkLayerAddressOption(body), false, nil + + case ndpTargetLinkLayerAddressOptionType: + return NDPTargetLinkLayerAddressOption(body), false, nil + + case ndpNonceOptionType: + return NDPNonceOption(body), false, nil + + case ndpRouteInformationType: + if numBodyBytes > ndpRouteInformationMaxLength { + return nil, true, fmt.Errorf("got %d bytes for NDP Route Information option's body, expected at max %d bytes: %w", numBodyBytes, ndpRouteInformationMaxLength, ErrNDPOptMalformedBody) + } + opt := NDPRouteInformation(body) + if err := opt.hasError(); err != nil { + return nil, true, err + } + + return opt, false, nil + + case ndpPrefixInformationType: + // Make sure the length of a Prefix Information option + // body is ndpPrefixInformationLength, as per RFC 4861 + // section 4.6.2. + if numBodyBytes != ndpPrefixInformationLength { + return nil, true, fmt.Errorf("got %d bytes for NDP Prefix Information option's body, expected %d bytes: %w", numBodyBytes, ndpPrefixInformationLength, ErrNDPOptMalformedBody) + } + + return NDPPrefixInformation(body), false, nil + + case ndpRecursiveDNSServerOptionType: + opt := NDPRecursiveDNSServer(body) + if err := opt.checkAddresses(); err != nil { + return nil, true, err + } + + return opt, false, nil + + case ndpDNSSearchListOptionType: + opt := NDPDNSSearchList(body) + if err := opt.checkDomainNames(); err != nil { + return nil, true, err + } + + return opt, false, nil + + default: + // We do not yet recognize the option, just skip for + // now. This is okay because RFC 4861 allows us to + // skip/ignore any unrecognized options. However, + // we MUST recognized all the options in RFC 4861. + // + // TODO(b/141487990): Handle all NDP options as defined + // by RFC 4861. + } + } +} + +// NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6. +type NDPOptions []byte + +// Iter returns an iterator of NDPOption. +// +// If check is true, Iter will do an integrity check on the options by iterating +// over it and returning an error if detected. +// +// See NDPOptionIterator for more information. +func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { + it := NDPOptionIterator{ + opts: bytes.NewBuffer(b), + } + + if check { + it2 := NDPOptionIterator{ + opts: bytes.NewBuffer(b), + } + + for { + if _, done, err := it2.Next(); err != nil || done { + return it, err + } + } + } + + return it, nil +} + +// Serialize serializes the provided list of NDP options into b. +// +// Note, b must be of sufficient size to hold all the options in s. See +// NDPOptionsSerializer.Length for details on the getting the total size +// of a serialized NDPOptionsSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { + done := 0 + + for _, o := range s { + l := paddedLength(o) + + if l == 0 { + continue + } + + b[0] = byte(o.kind()) + + // We know this safe because paddedLength would have returned + // 0 if o had an invalid length (> 255 * lengthByteUnits). + b[1] = uint8(l / lengthByteUnits) + + // Serialize NDP option body. + used := o.serializeInto(b[2:]) + + // Zero out remaining (padding) bytes, if any exists. + if used+2 < l { + common.ClearArray(b[used+2 : l]) + } + + b = b[l:] + done += l + } + + return done +} + +// NDPOption is the set of functions to be implemented by all NDP option types. +type NDPOption interface { + fmt.Stringer + + // kind returns the type of the receiver. + kind() ndpOptionIdentifier + + // length returns the length of the body of the receiver, in bytes. + length() int + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // Length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) int +} + +// paddedLength returns the length of o, in bytes, with any padding bytes, if +// required. +func paddedLength(o NDPOption) int { + l := o.length() + + if l == 0 { + return 0 + } + + // Length excludes the 2 Type and Length bytes. + l += 2 + + // Add extra bytes if needed to make sure the option is + // lengthByteUnits-byte aligned. We do this by adding lengthByteUnits-1 + // to l and then stripping off the last few LSBits from l. This will + // make sure that l is rounded up to the nearest unit of + // lengthByteUnits. This works since lengthByteUnits is a power of 2 + // (= 8). + mask := lengthByteUnits - 1 + l += mask + l &^= mask + + if l/lengthByteUnits > 255 { + // Should never happen because an option can only have a max + // value of 255 for its Length field, so just return 0 so this + // option does not get serialized. + // + // Returning 0 here will make sure that this option does not get + // serialized when NDPOptions.Serialize is called with the + // NDPOptionsSerializer that holds this option, effectively + // skipping this option during serialization. Also note that + // a value of zero for the Length field in an NDP option is + // invalid so this is another sign to the caller that this NDP + // option is malformed, as per RFC 4861 section 4.6. + return 0 + } + + return l +} + +// NDPOptionsSerializer is a serializer for NDP options. +type NDPOptionsSerializer []NDPOption + +// Length returns the total number of bytes required to serialize. +func (b NDPOptionsSerializer) Length() int { + l := 0 + + for _, o := range b { + l += paddedLength(o) + } + + return l +} + +// NDPNonceOption is the NDP Nonce Option as defined by RFC 3971 section 5.3.2. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPNonceOption []byte + +// kind implements NDPOption. +func (o NDPNonceOption) kind() ndpOptionIdentifier { + return ndpNonceOptionType +} + +// length implements NDPOption. +func (o NDPNonceOption) length() int { + return len(o) +} + +// serializeInto implements NDPOption. +func (o NDPNonceOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPNonceOption) String() string { + return fmt.Sprintf("%T(%x)", o, []byte(o)) +} + +// Nonce returns the nonce value this option holds. +func (o NDPNonceOption) Nonce() []byte { + return o +} + +// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option +// as defined by RFC 4861 section 4.6.1. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPSourceLinkLayerAddressOption tcpip.LinkAddress + +// kind implements NDPOption. +func (o NDPSourceLinkLayerAddressOption) kind() ndpOptionIdentifier { + return ndpSourceLinkLayerAddressOptionType +} + +// length implements NDPOption. +func (o NDPSourceLinkLayerAddressOption) length() int { + return len(o) +} + +// serializeInto implements NDPOption. +func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPSourceLinkLayerAddressOption) String() string { + return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) +} + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPSourceLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + +// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option +// as defined by RFC 4861 section 4.6.1. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPTargetLinkLayerAddressOption tcpip.LinkAddress + +// kind implements NDPOption. +func (o NDPTargetLinkLayerAddressOption) kind() ndpOptionIdentifier { + return ndpTargetLinkLayerAddressOptionType +} + +// length implements NDPOption. +func (o NDPTargetLinkLayerAddressOption) length() int { + return len(o) +} + +// serializeInto implements NDPOption. +func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPTargetLinkLayerAddressOption) String() string { + return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) +} + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + +// NDPPrefixInformation is the NDP Prefix Information option as defined by +// RFC 4861 section 4.6.2. +// +// The length, in bytes, of a valid NDP Prefix Information option body MUST be +// ndpPrefixInformationLength bytes. +type NDPPrefixInformation []byte + +// kind implements NDPOption. +func (o NDPPrefixInformation) kind() ndpOptionIdentifier { + return ndpPrefixInformationType +} + +// length implements NDPOption. +func (o NDPPrefixInformation) length() int { + return ndpPrefixInformationLength +} + +// serializeInto implements NDPOption. +func (o NDPPrefixInformation) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the Reserved1 field. + b[ndpPrefixInformationFlagsOffset] &^= ndpPrefixInformationReserved1FlagsMask + + // Zero out the Reserved2 field. + reserved2 := b[ndpPrefixInformationReserved2Offset:][:ndpPrefixInformationReserved2Length] + common.ClearArray(reserved2) + + return used +} + +// String implements fmt.Stringer. +func (o NDPPrefixInformation) String() string { + return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)", + o, + o.OnLinkFlag(), + o.AutonomousAddressConfigurationFlag(), + o.PreferredLifetime(), + o.ValidLifetime(), + o.Subnet()) +} + +// PrefixLength returns the value in the number of leading bits in the Prefix +// that are valid. +// +// Valid values are in the range [0, 128], but o may not always contain valid +// values. It is up to the caller to valdiate the Prefix Information option. +func (o NDPPrefixInformation) PrefixLength() uint8 { + return o[ndpPrefixInformationPrefixLengthOffset] +} + +// OnLinkFlag returns true of the prefix is considered on-link. On-link means +// that a forwarding node is not needed to send packets to other nodes on the +// same prefix. +// +// Note, when this function returns false, no statement is made about the +// on-link property of a prefix. That is, if OnLinkFlag returns false, the +// caller MUST NOT conclude that the prefix is off-link and MUST NOT update any +// previously stored state for this prefix about its on-link status. +func (o NDPPrefixInformation) OnLinkFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationOnLinkFlagMask != 0 +} + +// AutonomousAddressConfigurationFlag returns true if the prefix can be used for +// Stateless Address Auto-Configuration (as specified in RFC 4862). +func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationAutoAddrConfFlagMask != 0 +} + +// ValidLifetime returns the length of time that the prefix is valid for the +// purpose of on-link determination. This value is relative to the send time of +// the packet that the Prefix Information option was present in. +// +// Note, a value of 0 implies the prefix should not be considered as on-link, +// and a value of infinity/forever is represented by +// NDPInfiniteLifetime. +func (o NDPPrefixInformation) ValidLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) +} + +// PreferredLifetime returns the length of time that an address generated from +// the prefix via Stateless Address Auto-Configuration remains preferred. This +// value is relative to the send time of the packet that the Prefix Information +// option was present in. +// +// Note, a value of 0 implies that addresses generated from the prefix should +// no longer remain preferred, and a value of infinity is represented by +// NDPInfiniteLifetime. +// +// Also note that the value of this field MUST NOT exceed the Valid Lifetime +// field to avoid preferring addresses that are no longer valid, for the +// purpose of Stateless Address Auto-Configuration. +func (o NDPPrefixInformation) PreferredLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationPreferredLifetimeOffset:])) +} + +// Prefix returns an IPv6 address or a prefix of an IPv6 address. The Prefix +// Length field (see NDPPrefixInformation.PrefixLength) contains the number +// of valid leading bits in the prefix. +// +// Hosts SHOULD ignore an NDP Prefix Information option where the Prefix field +// holds the link-local prefix (fe80::). +func (o NDPPrefixInformation) Prefix() tcpip.Address { + return tcpip.AddrFrom16Slice(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize]) +} + +// Subnet returns the Prefix field and Prefix Length field represented in a +// tcpip.Subnet. +func (o NDPPrefixInformation) Subnet() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: o.Prefix(), + PrefixLen: int(o.PrefixLength()), + } + return addrWithPrefix.Subnet() +} + +// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by +// RFC 8106 section 5.1. +// +// To make sure that the option meets its minimum length and does not end in the +// middle of a DNS server's IPv6 address, the length of a valid +// NDPRecursiveDNSServer must meet the following constraint: +// +// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0 +type NDPRecursiveDNSServer []byte + +// Type returns the type of an NDP Recursive DNS Server option. +// +// kind implements NDPOption. +func (NDPRecursiveDNSServer) kind() ndpOptionIdentifier { + return ndpRecursiveDNSServerOptionType +} + +// length implements NDPOption. +func (o NDPRecursiveDNSServer) length() int { + return len(o) +} + +// serializeInto implements NDPOption. +func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + common.ClearArray(b[0:ndpRecursiveDNSServerLifetimeOffset]) + + return used +} + +// String implements fmt.Stringer. +func (o NDPRecursiveDNSServer) String() string { + lt := o.Lifetime() + addrs, err := o.Addresses() + if err != nil { + return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err) + } + return fmt.Sprintf("%T(%s valid for %s)", o, addrs, lt) +} + +// Lifetime returns the length of time that the DNS server addresses +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the addresses should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +// +// Lifetime may panic if o does not have enough bytes to hold the Lifetime +// field. +func (o NDPRecursiveDNSServer) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:])) +} + +// Addresses returns the recursive DNS server IPv6 addresses that may be +// used for name resolution. +// +// Note, the addresses MAY be link-local addresses. +func (o NDPRecursiveDNSServer) Addresses() ([]tcpip.Address, error) { + var addrs []tcpip.Address + return addrs, o.iterAddresses(func(addr tcpip.Address) { addrs = append(addrs, addr) }) +} + +// checkAddresses iterates over the addresses in an NDP Recursive DNS Server +// option and returns any error it encounters. +func (o NDPRecursiveDNSServer) checkAddresses() error { + return o.iterAddresses(nil) +} + +// iterAddresses iterates over the addresses in an NDP Recursive DNS Server +// option and calls a function with each valid unicast IPv6 address. +// +// Note, the addresses MAY be link-local addresses. +func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error { + if l := len(o); l < minNDPRecursiveDNSServerBodySize { + return fmt.Errorf("got %d bytes for NDP Recursive DNS Server option's body, expected at least %d bytes: %w", l, minNDPRecursiveDNSServerBodySize, io.ErrUnexpectedEOF) + } + + o = o[ndpRecursiveDNSServerAddressesOffset:] + l := len(o) + if l%IPv6AddressSize != 0 { + return fmt.Errorf("NDP Recursive DNS Server option's body ends in the middle of an IPv6 address (addresses body size = %d bytes): %w", l, ErrNDPOptMalformedBody) + } + + for i := 0; len(o) != 0; i++ { + addr := tcpip.AddrFrom16Slice(o[:IPv6AddressSize]) + if !IsV6UnicastAddress(addr) { + return fmt.Errorf("%d-th address (%s) in NDP Recursive DNS Server option is not a valid unicast IPv6 address: %w", i, addr, ErrNDPOptMalformedBody) + } + + if fn != nil { + fn(addr) + } + + o = o[IPv6AddressSize:] + } + + return nil +} + +// NDPDNSSearchList is the NDP DNS Search List option, as defined by +// RFC 8106 section 5.2. +type NDPDNSSearchList []byte + +// kind implements NDPOption. +func (o NDPDNSSearchList) kind() ndpOptionIdentifier { + return ndpDNSSearchListOptionType +} + +// length implements NDPOption. +func (o NDPDNSSearchList) length() int { + return len(o) +} + +// serializeInto implements NDPOption. +func (o NDPDNSSearchList) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + common.ClearArray(b[0:ndpDNSSearchListLifetimeOffset]) + + return used +} + +// String implements fmt.Stringer. +func (o NDPDNSSearchList) String() string { + lt := o.Lifetime() + domainNames, err := o.DomainNames() + if err != nil { + return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err) + } + return fmt.Sprintf("%T(%s valid for %s)", o, domainNames, lt) +} + +// Lifetime returns the length of time that the DNS search list of domain names +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the domain names should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +func (o NDPDNSSearchList) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpDNSSearchListLifetimeOffset:])) +} + +// DomainNames returns a DNS search list of domain names. +// +// DomainNames will parse the backing buffer as outlined by RFC 1035 section +// 3.1 and return a list of strings, with all domain names in lower case. +func (o NDPDNSSearchList) DomainNames() ([]string, error) { + var domainNames []string + return domainNames, o.iterDomainNames(func(domainName string) { domainNames = append(domainNames, domainName) }) +} + +// checkDomainNames iterates over the domain names in an NDP DNS Search List +// option and returns any error it encounters. +func (o NDPDNSSearchList) checkDomainNames() error { + return o.iterDomainNames(nil) +} + +// iterDomainNames iterates over the domain names in an NDP DNS Search List +// option and calls a function with each valid domain name. +func (o NDPDNSSearchList) iterDomainNames(fn func(string)) error { + if l := len(o); l < minNDPDNSSearchListBodySize { + return fmt.Errorf("got %d bytes for NDP DNS Search List option's body, expected at least %d bytes: %w", l, minNDPDNSSearchListBodySize, io.ErrUnexpectedEOF) + } + + var searchList bytes.Reader + searchList.Reset(o[ndpDNSSearchListDomainNamesOffset:]) + + var scratch [maxDomainNameLength]byte + domainName := bytes.NewBuffer(scratch[:]) + + // Parse the domain names, as per RFC 1035 section 3.1. + for searchList.Len() != 0 { + domainName.Reset() + + // Parse a label within a domain name, as per RFC 1035 section 3.1. + for { + // The first byte is the label length. + labelLenByte, err := searchList.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading a label's length: %s", err)) + } + + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected + // once we start parsing a domain name; we expect the buffer to contain + // enough bytes for the whole domain name. + return fmt.Errorf("unexpected exhausted buffer while parsing a new label for a domain from NDP Search List option: %w", io.ErrUnexpectedEOF) + } + labelLen := int(labelLenByte) + + // A zero-length label implies the end of a domain name. + if labelLen == 0 { + // If the domain name is empty or we have no callback function, do + // nothing further with the current domain name. + if domainName.Len() == 0 || fn == nil { + break + } + + // Ignore the trailing period in the parsed domain name. + domainName.Truncate(domainName.Len() - 1) + fn(domainName.String()) + break + } + + // The label's length must not exceed the maximum length for a label. + if labelLen > maxDomainNameLabelLength { + return fmt.Errorf("label length of %d bytes is greater than the max label length of %d bytes for an NDP Search List option: %w", labelLen, maxDomainNameLabelLength, ErrNDPOptMalformedBody) + } + + // The label (and trailing period) must not make the domain name too long. + if labelLen+1 > domainName.Cap()-domainName.Len() { + return fmt.Errorf("label would make an NDP Search List option's domain name longer than the max domain name length of %d bytes: %w", maxDomainNameLength, ErrNDPOptMalformedBody) + } + + // Copy the label and add a trailing period. + for i := 0; i < labelLen; i++ { + b, err := searchList.ReadByte() + if err != nil { + if err != io.EOF { + panic(fmt.Sprintf("unexpected error when reading domain name's label: %s", err)) + } + + return fmt.Errorf("read %d out of %d bytes for a domain name's label from NDP Search List option: %w", i, labelLen, io.ErrUnexpectedEOF) + } + + // As per RFC 1035 section 2.3.1: + // 1) the label must only contain ASCII include letters, digits and + // hyphens + // 2) the first character in a label must be a letter + // 3) the last letter in a label must be a letter or digit + + if !isLetter(b) { + if i == 0 { + return fmt.Errorf("first character of a domain name's label in an NDP Search List option must be a letter, got character code = %d: %w", b, ErrNDPOptMalformedBody) + } + + if b == '-' { + if i == labelLen-1 { + return fmt.Errorf("last character of a domain name's label in an NDP Search List option must not be a hyphen (-): %w", ErrNDPOptMalformedBody) + } + } else if !isDigit(b) { + return fmt.Errorf("domain name's label in an NDP Search List option may only contain letters, digits and hyphens, got character code = %d: %w", b, ErrNDPOptMalformedBody) + } + } + + // If b is an upper case character, make it lower case. + if isUpperLetter(b) { + b = b - 'A' + 'a' + } + + if err := domainName.WriteByte(b); err != nil { + panic(fmt.Sprintf("unexpected error writing label to domain name buffer: %s", err)) + } + } + if err := domainName.WriteByte('.'); err != nil { + panic(fmt.Sprintf("unexpected error writing trailing period to domain name buffer: %s", err)) + } + } + } + + return nil +} + +func isLetter(b byte) bool { + return b >= 'a' && b <= 'z' || isUpperLetter(b) +} + +func isUpperLetter(b byte) bool { + return b >= 'A' && b <= 'Z' +} + +func isDigit(b byte) bool { + return b >= '0' && b <= '9' +} + +// As per RFC 4191 section 2.3, +// +// 2.3. Route Information Option +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Length | Prefix Length |Resvd|Prf|Resvd| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Route Lifetime | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Prefix (Variable Length) | +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Fields: +// +// Type 24 +// +// +// Length 8-bit unsigned integer. The length of the option +// (including the Type and Length fields) in units of 8 +// octets. The Length field is 1, 2, or 3 depending on the +// Prefix Length. If Prefix Length is greater than 64, then +// Length must be 3. If Prefix Length is greater than 0, +// then Length must be 2 or 3. If Prefix Length is zero, +// then Length must be 1, 2, or 3. +const ( + ndpRouteInformationType = ndpOptionIdentifier(24) + ndpRouteInformationMaxLength = 22 + + ndpRouteInformationPrefixLengthIdx = 0 + ndpRouteInformationFlagsIdx = 1 + ndpRouteInformationPrfShift = 3 + ndpRouteInformationPrfMask = 3 << ndpRouteInformationPrfShift + ndpRouteInformationRouteLifetimeIdx = 2 + ndpRouteInformationRoutePrefixIdx = 6 +) + +// NDPRouteInformation is the NDP Router Information option, as defined by +// RFC 4191 section 2.3. +type NDPRouteInformation []byte + +func (NDPRouteInformation) kind() ndpOptionIdentifier { + return ndpRouteInformationType +} + +func (o NDPRouteInformation) length() int { + return len(o) +} + +func (o NDPRouteInformation) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPRouteInformation) String() string { + return fmt.Sprintf("%T", o) +} + +// PrefixLength returns the length of the prefix. +func (o NDPRouteInformation) PrefixLength() uint8 { + return o[ndpRouteInformationPrefixLengthIdx] +} + +// RoutePreference returns the preference of the route over other routes to the +// same destination but through a different router. +func (o NDPRouteInformation) RoutePreference() NDPRoutePreference { + return NDPRoutePreference((o[ndpRouteInformationFlagsIdx] & ndpRouteInformationPrfMask) >> ndpRouteInformationPrfShift) +} + +// RouteLifetime returns the lifetime of the route. +// +// Note, a value of 0 implies the route is now invalid and a value of +// infinity/forever is represented by NDPInfiniteLifetime. +func (o NDPRouteInformation) RouteLifetime() time.Duration { + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRouteInformationRouteLifetimeIdx:])) +} + +// Prefix returns the prefix of the destination subnet this route is for. +func (o NDPRouteInformation) Prefix() (tcpip.Subnet, error) { + prefixLength := int(o.PrefixLength()) + if max := IPv6AddressSize * 8; prefixLength > max { + return tcpip.Subnet{}, fmt.Errorf("got prefix length = %d, want <= %d", prefixLength, max) + } + + prefix := o[ndpRouteInformationRoutePrefixIdx:] + var addrBytes [IPv6AddressSize]byte + if n := copy(addrBytes[:], prefix); n != len(prefix) { + panic(fmt.Sprintf("got copy(addrBytes, prefix) = %d, want = %d", n, len(prefix))) + } + + return tcpip.AddressWithPrefix{ + Address: tcpip.AddrFrom16(addrBytes), + PrefixLen: prefixLength, + }.Subnet(), nil +} + +func (o NDPRouteInformation) hasError() error { + l := len(o) + if l < ndpRouteInformationRoutePrefixIdx { + return fmt.Errorf("%T too small, got = %d bytes: %w", o, l, ErrNDPOptMalformedBody) + } + + prefixLength := int(o.PrefixLength()) + if max := IPv6AddressSize * 8; prefixLength > max { + return fmt.Errorf("got prefix length = %d, want <= %d: %w", prefixLength, max, ErrNDPOptMalformedBody) + } + + // Length 8-bit unsigned integer. The length of the option + // (including the Type and Length fields) in units of 8 + // octets. The Length field is 1, 2, or 3 depending on the + // Prefix Length. If Prefix Length is greater than 64, then + // Length must be 3. If Prefix Length is greater than 0, + // then Length must be 2 or 3. If Prefix Length is zero, + // then Length must be 1, 2, or 3. + l += 2 // Add 2 bytes for the type and length bytes. + lengthField := l / lengthByteUnits + if prefixLength > 64 { + if lengthField != 3 { + return fmt.Errorf("Length field must be 3 when Prefix Length (%d) is > 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody) + } + } else if prefixLength > 0 { + if lengthField != 2 && lengthField != 3 { + return fmt.Errorf("Length field must be 2 or 3 when Prefix Length (%d) is between 0 and 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody) + } + } else if lengthField == 0 || lengthField > 3 { + return fmt.Errorf("Length field must be 1, 2, or 3 when Prefix Length is zero (got = %d): %w", lengthField, ErrNDPOptMalformedBody) + } + + return nil +} diff --git a/internal/gtcpip/header/ndp_router_advert.go b/internal/gtcpip/header/ndp_router_advert.go new file mode 100644 index 0000000..e2456c0 --- /dev/null +++ b/internal/gtcpip/header/ndp_router_advert.go @@ -0,0 +1,204 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + "time" +) + +var _ fmt.Stringer = NDPRoutePreference(0) + +// NDPRoutePreference is the preference values for default routers or +// more-specific routes. +// +// As per RFC 4191 section 2.1, +// +// Default router preferences and preferences for more-specific routes +// are encoded the same way. +// +// Preference values are encoded as a two-bit signed integer, as +// follows: +// +// 01 High +// 00 Medium (default) +// 11 Low +// 10 Reserved - MUST NOT be sent +// +// Note that implementations can treat the value as a two-bit signed +// integer. +// +// Having just three values reinforces that they are not metrics and +// more values do not appear to be necessary for reasonable scenarios. +type NDPRoutePreference uint8 + +const ( + // HighRoutePreference indicates a high preference, as per + // RFC 4191 section 2.1. + HighRoutePreference NDPRoutePreference = 0b01 + + // MediumRoutePreference indicates a medium preference, as per + // RFC 4191 section 2.1. + // + // This is the default preference value. + MediumRoutePreference = 0b00 + + // LowRoutePreference indicates a low preference, as per + // RFC 4191 section 2.1. + LowRoutePreference = 0b11 + + // ReservedRoutePreference is a reserved preference value, as per + // RFC 4191 section 2.1. + // + // It MUST NOT be sent. + ReservedRoutePreference = 0b10 +) + +// String implements fmt.Stringer. +func (p NDPRoutePreference) String() string { + switch p { + case HighRoutePreference: + return "HighRoutePreference" + case MediumRoutePreference: + return "MediumRoutePreference" + case LowRoutePreference: + return "LowRoutePreference" + case ReservedRoutePreference: + return "ReservedRoutePreference" + default: + return fmt.Sprintf("NDPRoutePreference(%d)", p) + } +} + +// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details. +type NDPRouterAdvert []byte + +// As per RFC 4191 section 2.2, +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Code | Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Reachable Time | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Retrans Timer | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options ... +// +-+-+-+-+-+-+-+-+-+-+-+- +const ( + // NDPRAMinimumSize is the minimum size of a valid NDP Router + // Advertisement message (body of an ICMPv6 packet). + NDPRAMinimumSize = 12 + + // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field + // within an NDPRouterAdvert. + ndpRACurrHopLimitOffset = 0 + + // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags + // within an NDPRouterAdvert. + ndpRAFlagsOffset = 1 + + // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address + // Configuration flag within the bit-field/flags byte of an + // NDPRouterAdvert. + ndpRAManagedAddrConfFlagMask = (1 << 7) + + // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag + // within the bit-field/flags byte of an NDPRouterAdvert. + ndpRAOtherConfFlagMask = (1 << 6) + + // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceShift = 3 + + // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift) + + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime + // field within an NDPRouterAdvert. + ndpRARouterLifetimeOffset = 2 + + // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time + // field within an NDPRouterAdvert. + ndpRAReachableTimeOffset = 4 + + // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer + // field within an NDPRouterAdvert. + ndpRARetransTimerOffset = 8 + + // ndpRAOptionsOffset is the start of the NDP options in an + // NDPRouterAdvert. + ndpRAOptionsOffset = 12 +) + +// CurrHopLimit returns the value of the Curr Hop Limit field. +func (b NDPRouterAdvert) CurrHopLimit() uint8 { + return b[ndpRACurrHopLimitOffset] +} + +// ManagedAddrConfFlag returns the value of the Managed Address Configuration +// flag. +func (b NDPRouterAdvert) ManagedAddrConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0 +} + +// OtherConfFlag returns the value of the Other Configuration flag. +func (b NDPRouterAdvert) OtherConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 +} + +// DefaultRouterPreference returns the Default Router Preference field. +func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference { + return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift) +} + +// RouterLifetime returns the lifetime associated with the default router. A +// value of 0 means the source of the Router Advertisement is not a default +// router and SHOULD NOT appear on the default router list. Note, a value of 0 +// only means that the router should not be used as a default router, it does +// not apply to other information contained in the Router Advertisement. +func (b NDPRouterAdvert) RouterLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.2. + return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:])) +} + +// ReachableTime returns the time that a node assumes a neighbor is reachable +// after having received a reachability confirmation. A value of 0 means +// that it is unspecified by the source of the Router Advertisement message. +func (b NDPRouterAdvert) ReachableTime() time.Duration { + // The field is the time in milliseconds, as per RFC 4861 section 4.2. + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:])) +} + +// RetransTimer returns the time between retransmitted Neighbor Solicitation +// messages. A value of 0 means that it is unspecified by the source of the +// Router Advertisement message. +func (b NDPRouterAdvert) RetransTimer() time.Duration { + // The field is the time in milliseconds, as per RFC 4861 section 4.2. + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:])) +} + +// Options returns an NDPOptions of the options body. +func (b NDPRouterAdvert) Options() NDPOptions { + return NDPOptions(b[ndpRAOptionsOffset:]) +} diff --git a/internal/gtcpip/header/ndp_router_solicit.go b/internal/gtcpip/header/ndp_router_solicit.go new file mode 100644 index 0000000..5ca2e5c --- /dev/null +++ b/internal/gtcpip/header/ndp_router_solicit.go @@ -0,0 +1,36 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.1 for more details. +type NDPRouterSolicit []byte + +const ( + // NDPRSMinimumSize is the minimum size of a valid NDP Router + // Solicitation message (body of an ICMPv6 packet). + NDPRSMinimumSize = 4 + + // ndpRSOptionsOffset is the start of the NDP options in an + // NDPRouterSolicit. + ndpRSOptionsOffset = 4 +) + +// Options returns an NDPOptions of the options body. +func (b NDPRouterSolicit) Options() NDPOptions { + return NDPOptions(b[ndpRSOptionsOffset:]) +} diff --git a/internal/gtcpip/header/ndpoptionidentifier_string.go b/internal/gtcpip/header/ndpoptionidentifier_string.go new file mode 100644 index 0000000..55ab1d7 --- /dev/null +++ b/internal/gtcpip/header/ndpoptionidentifier_string.go @@ -0,0 +1,58 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type ndpOptionIdentifier"; DO NOT EDIT. + +package header + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[ndpSourceLinkLayerAddressOptionType-1] + _ = x[ndpTargetLinkLayerAddressOptionType-2] + _ = x[ndpPrefixInformationType-3] + _ = x[ndpNonceOptionType-14] + _ = x[ndpRecursiveDNSServerOptionType-25] + _ = x[ndpDNSSearchListOptionType-31] +} + +const ( + _ndpOptionIdentifier_name_0 = "ndpSourceLinkLayerAddressOptionTypendpTargetLinkLayerAddressOptionTypendpPrefixInformationType" + _ndpOptionIdentifier_name_1 = "ndpNonceOptionType" + _ndpOptionIdentifier_name_2 = "ndpRecursiveDNSServerOptionType" + _ndpOptionIdentifier_name_3 = "ndpDNSSearchListOptionType" +) + +var ( + _ndpOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94} +) + +func (i ndpOptionIdentifier) String() string { + switch { + case 1 <= i && i <= 3: + i -= 1 + return _ndpOptionIdentifier_name_0[_ndpOptionIdentifier_index_0[i]:_ndpOptionIdentifier_index_0[i+1]] + case i == 14: + return _ndpOptionIdentifier_name_1 + case i == 25: + return _ndpOptionIdentifier_name_2 + case i == 31: + return _ndpOptionIdentifier_name_3 + default: + return "ndpOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" + } +} diff --git a/internal/gtcpip/header/netip.go b/internal/gtcpip/header/netip.go new file mode 100644 index 0000000..a3502ab --- /dev/null +++ b/internal/gtcpip/header/netip.go @@ -0,0 +1,35 @@ +package header + +import "net/netip" + +func (b IPv4) SourceAddr() netip.Addr { + return netip.AddrFrom4([4]byte(b[srcAddr : srcAddr+IPv4AddressSize])) +} + +func (b IPv4) DestinationAddr() netip.Addr { + return netip.AddrFrom4([4]byte(b[dstAddr : dstAddr+IPv4AddressSize])) +} + +func (b IPv4) SetSourceAddr(addr netip.Addr) { + copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.AsSlice()) +} + +func (b IPv4) SetDestinationAddr(addr netip.Addr) { + copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.AsSlice()) +} + +func (b IPv6) SourceAddr() netip.Addr { + return netip.AddrFrom16([16]byte(b[v6SrcAddr:][:IPv6AddressSize])) +} + +func (b IPv6) DestinationAddr() netip.Addr { + return netip.AddrFrom16([16]byte(b[v6DstAddr:][:IPv6AddressSize])) +} + +func (b IPv6) SetSourceAddr(addr netip.Addr) { + copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice()) +} + +func (b IPv6) SetDestinationAddr(addr netip.Addr) { + copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice()) +} diff --git a/internal/gtcpip/header/tcp.go b/internal/gtcpip/header/tcp.go new file mode 100644 index 0000000..5855253 --- /dev/null +++ b/internal/gtcpip/header/tcp.go @@ -0,0 +1,727 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/seqnum" + + "github.com/google/btree" +) + +// These constants are the offsets of the respective fields in the TCP header. +const ( + TCPSrcPortOffset = 0 + TCPDstPortOffset = 2 + TCPSeqNumOffset = 4 + TCPAckNumOffset = 8 + TCPDataOffset = 12 + TCPFlagsOffset = 13 + TCPWinSizeOffset = 14 + TCPChecksumOffset = 16 + TCPUrgentPtrOffset = 18 +) + +const ( + // MaxWndScale is maximum allowed window scaling, as described in + // RFC 1323, section 2.3, page 11. + MaxWndScale = 14 + + // TCPMaxSACKBlocks is the maximum number of SACK blocks that can + // be encoded in a TCP option field. + TCPMaxSACKBlocks = 4 +) + +// TCPFlags is the dedicated type for TCP flags. +type TCPFlags uint8 + +// Intersects returns true iff there are flags common to both f and o. +func (f TCPFlags) Intersects(o TCPFlags) bool { + return f&o != 0 +} + +// Contains returns true iff all the flags in o are contained within f. +func (f TCPFlags) Contains(o TCPFlags) bool { + return f&o == o +} + +// String implements Stringer.String. +func (f TCPFlags) String() string { + flagsStr := []byte("FSRPAUEC") + for i := range flagsStr { + if f&(1<> 4) * 4 +} + +// Payload returns the data in the TCP packet. +func (b TCP) Payload() []byte { + return b[b.DataOffset():] +} + +// Flags returns the flags field of the TCP header. +func (b TCP) Flags() TCPFlags { + return TCPFlags(b[TCPFlagsOffset]) +} + +// WindowSize returns the "window size" field of the TCP header. +func (b TCP) WindowSize() uint16 { + return binary.BigEndian.Uint16(b[TCPWinSizeOffset:]) +} + +// Checksum returns the "checksum" field of the TCP header. +func (b TCP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[TCPChecksumOffset:]) +} + +// UrgentPointer returns the "urgent pointer" field of the TCP header. +func (b TCP) UrgentPointer() uint16 { + return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:]) +} + +// SetSourcePort sets the "source port" field of the TCP header. +func (b TCP) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port) +} + +// SetDestinationPort sets the "destination port" field of the TCP header. +func (b TCP) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port) +} + +// SetChecksum sets the checksum field of the TCP header. +func (b TCP) SetChecksum(xsum uint16) { + checksum.Put(b[TCPChecksumOffset:], xsum) +} + +// SetDataOffset sets the data offset field of the TCP header. headerLen should +// be the length of the TCP header in bytes. +func (b TCP) SetDataOffset(headerLen uint8) { + b[TCPDataOffset] = (headerLen / 4) << 4 +} + +// SetSequenceNumber sets the sequence number field of the TCP header. +func (b TCP) SetSequenceNumber(seqNum uint32) { + binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum) +} + +// SetAckNumber sets the ack number field of the TCP header. +func (b TCP) SetAckNumber(ackNum uint32) { + binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum) +} + +// SetFlags sets the flags field of the TCP header. +func (b TCP) SetFlags(flags uint8) { + b[TCPFlagsOffset] = flags +} + +// SetWindowSize sets the window size field of the TCP header. +func (b TCP) SetWindowSize(rcvwnd uint16) { + binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) +} + +// SetUrgentPointer sets the window size field of the TCP header. +func (b TCP) SetUrgentPointer(urgentPointer uint16) { + binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer) +} + +// CalculateChecksum calculates the checksum of the TCP segment. +// partialChecksum is the checksum of the network-layer pseudo-header +// and the checksum of the segment data. +func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { + // Calculate the rest of the checksum. + return checksum.Checksum(b[:b.DataOffset()], partialChecksum) +} + +// IsChecksumValid returns true iff the TCP header's checksum is valid. +func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool { + xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength) + xsum = checksum.Combine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + +// Options returns a slice that holds the unparsed TCP options in the segment. +func (b TCP) Options() []byte { + return b[TCPMinimumSize:b.DataOffset()] +} + +// ParsedOptions returns a TCPOptions structure which parses and caches the TCP +// option values in the TCP segment. NOTE: Invoking this function repeatedly is +// expensive as it reparses the options on each invocation. +func (b TCP) ParsedOptions() TCPOptions { + return ParseTCPOptions(b.Options()) +} + +func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { + binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq) + binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack) + b[TCPFlagsOffset] = uint8(flags) + binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) +} + +// Encode encodes all the fields of the TCP header. +func (b TCP) Encode(t *TCPFields) { + b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize) + b.SetSourcePort(t.SrcPort) + b.SetDestinationPort(t.DstPort) + b.SetDataOffset(t.DataOffset) + b.SetChecksum(t.Checksum) + b.SetUrgentPointer(t.UrgentPointer) +} + +// EncodePartial updates a subset of the fields of the TCP header. It is useful +// in cases when similar segments are produced. +func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { + // Add the total length and "flags" field contributions to the checksum. + // We don't use the flags field directly from the header because it's a + // one-byte field with an odd offset, so it would be accounted for + // incorrectly by the Checksum routine. + tmp := make([]byte, 4) + binary.BigEndian.PutUint16(tmp, length) + binary.BigEndian.PutUint16(tmp[2:], uint16(flags)) + xsum := checksum.Checksum(tmp, partialChecksum) + + // Encode the passed-in fields. + b.encodeSubset(seqnum, acknum, flags, rcvwnd) + + // Add the contributions of the passed-in fields to the checksum. + xsum = checksum.Checksum(b[TCPSeqNumOffset:TCPSeqNumOffset+8], xsum) + xsum = checksum.Checksum(b[TCPWinSizeOffset:TCPWinSizeOffset+2], xsum) + + // Encode the checksum. + b.SetChecksum(^xsum) +} + +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} + +// ParseSynOptions parses the options received in a SYN segment and returns the +// relevant ones. opts should point to the option part of the TCP header. +func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { + limit := len(opts) + + synOpts := TCPSynOptions{ + // Per RFC 1122, page 85: "If an MSS option is not received at + // connection setup, TCP MUST assume a default send MSS of 536." + MSS: TCPDefaultMSS, + // If no window scale option is specified, WS in options is + // returned as -1; this is because the absence of the option + // indicates that the we cannot use window scaling on the + // receive end either. + WS: -1, + } + + for i := 0; i < limit; { + switch opts[i] { + case TCPOptionEOL: + i = limit + case TCPOptionNOP: + i++ + case TCPOptionMSS: + if i+4 > limit || opts[i+1] != 4 { + return synOpts + } + mss := uint16(opts[i+2])<<8 | uint16(opts[i+3]) + if mss == 0 { + return synOpts + } + synOpts.MSS = mss + if mss < TCPMinimumSendMSS { + synOpts.MSS = TCPMinimumSendMSS + } + i += 4 + + case TCPOptionWS: + if i+3 > limit || opts[i+1] != 3 { + return synOpts + } + ws := int(opts[i+2]) + if ws > MaxWndScale { + ws = MaxWndScale + } + synOpts.WS = ws + i += 3 + + case TCPOptionTS: + if i+10 > limit || opts[i+1] != 10 { + return synOpts + } + synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:]) + if isAck { + // If the segment is a SYN-ACK then store the Timestamp Echo Reply + // in the segment. + synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:]) + } + synOpts.TS = true + i += 10 + case TCPOptionSACKPermitted: + if i+2 > limit || opts[i+1] != 2 { + return synOpts + } + synOpts.SACKPermitted = true + i += 2 + + default: + // We don't recognize this option, just skip over it. + if i+2 > limit { + return synOpts + } + l := int(opts[i+1]) + // If the length is incorrect or if l+i overflows the + // total options length then return false. + if l < 2 || i+l > limit { + return synOpts + } + i += l + } + } + + return synOpts +} + +// ParseTCPOptions extracts and stores all known options in the provided byte +// slice in a TCPOptions structure. +func ParseTCPOptions(b []byte) TCPOptions { + opts := TCPOptions{} + limit := len(b) + for i := 0; i < limit; { + switch b[i] { + case TCPOptionEOL: + i = limit + case TCPOptionNOP: + i++ + case TCPOptionTS: + if i+10 > limit || (b[i+1] != 10) { + return opts + } + opts.TS = true + opts.TSVal = binary.BigEndian.Uint32(b[i+2:]) + opts.TSEcr = binary.BigEndian.Uint32(b[i+6:]) + i += 10 + case TCPOptionSACK: + if i+2 > limit { + // Malformed SACK block, just return and stop parsing. + return opts + } + sackOptionLen := int(b[i+1]) + if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { + // Malformed SACK block, just return and stop parsing. + return opts + } + numBlocks := (sackOptionLen - 2) / 8 + opts.SACKBlocks = []SACKBlock{} + for j := 0; j < numBlocks; j++ { + start := binary.BigEndian.Uint32(b[i+2+j*8:]) + end := binary.BigEndian.Uint32(b[i+2+j*8+4:]) + opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{ + Start: seqnum.Value(start), + End: seqnum.Value(end), + }) + } + i += sackOptionLen + default: + // We don't recognize this option, just skip over it. + if i+2 > limit { + return opts + } + l := int(b[i+1]) + // If the length is incorrect or if l+i overflows the + // total options length then return false. + if l < 2 || i+l > limit { + return opts + } + i += l + } + } + return opts +} + +// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in +// the supplied buffer. If the provided buffer is not large enough then it just +// returns without encoding anything. It returns the number of bytes written to +// the provided buffer. +func EncodeMSSOption(mss uint32, b []byte) int { + if len(b) < TCPOptionMSSLength { + return 0 + } + b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss) + return TCPOptionMSSLength +} + +// EncodeWSOption encodes the WS TCP option with the WS value in the +// provided buffer. If the provided buffer is not large enough then it just +// returns without encoding anything. It returns the number of bytes written to +// the provided buffer. +func EncodeWSOption(ws int, b []byte) int { + if len(b) < TCPOptionWSLength { + return 0 + } + b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws) + return int(b[1]) +} + +// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp +// option into the provided buffer. If the buffer is smaller than expected it +// just returns without encoding anything. It returns the number of bytes +// written to the provided buffer. +func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { + if len(b) < TCPOptionTSLength { + return 0 + } + b[0], b[1] = TCPOptionTS, TCPOptionTSLength + binary.BigEndian.PutUint32(b[2:], tsVal) + binary.BigEndian.PutUint32(b[6:], tsEcr) + return int(b[1]) +} + +// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided +// buffer. If the buffer is smaller than required it just returns without +// encoding anything. It returns the number of bytes written to the provided +// buffer. +func EncodeSACKPermittedOption(b []byte) int { + if len(b) < TCPOptionSackPermittedLength { + return 0 + } + + b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength + return int(b[1]) +} + +// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block +// in the provided slice. It tries to fit in as many blocks as possible based on +// number of bytes available in the provided buffer. It returns the number of +// bytes written to the provided buffer. +func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int { + if len(sackBlocks) == 0 { + return 0 + } + l := len(sackBlocks) + if l > TCPMaxSACKBlocks { + l = TCPMaxSACKBlocks + } + if ll := (len(b) - 2) / 8; ll < l { + l = ll + } + if l == 0 { + // There is not enough space in the provided buffer to add + // any SACK blocks. + return 0 + } + b[0] = TCPOptionSACK + b[1] = byte(l*8 + 2) + for i := 0; i < l; i++ { + binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start)) + binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End)) + } + return int(b[1]) +} + +// EncodeNOP adds an explicit NOP to the option list. +func EncodeNOP(b []byte) int { + if len(b) == 0 { + return 0 + } + b[0] = TCPOptionNOP + return 1 +} + +// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align +// the option buffer. It adds padding bytes after the offset specified and +// returns the number of padding bytes added. The passed in options slice +// must have space for the padding bytes. +func AddTCPOptionPadding(options []byte, offset int) int { + paddingToAdd := -offset & 3 + // Now add any padding bytes that might be required to quad align the + // options. + for i := offset; i < offset+paddingToAdd; i++ { + options[i] = TCPOptionNOP + } + return paddingToAdd +} + +// Acceptable checks if a segment that starts at segSeq and has length segLen is +// "acceptable" for arriving in a receive window that starts at rcvNxt and ends +// before rcvAcc, according to the table on page 26 and 69 of RFC 793. +func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { + if rcvNxt == rcvAcc { + return segLen == 0 && segSeq == rcvNxt + } + if segLen == 0 { + // rcvWnd is incremented by 1 because that is Linux's behavior despite the + // RFC. + return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) + } + // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming + // the payload, so we'll accept any payload that overlaps the receive window. + // segSeq < rcvAcc is more correct according to RFC, however, Linux does it + // differently, it uses segSeq <= rcvAcc, we'd want to keep the same behavior + // as Linux. + return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc) +} + +// TCPValid returns true if the pkt has a valid TCP header. It checks whether: +// - The data offset is too small. +// - The data offset is too large. +// - The checksum is invalid. +// +// TCPValid corresponds to net/netfilter/nf_conntrack_proto_tcp.c:tcp_error. +func TCPValid(hdr TCP, payloadChecksum func() uint16, payloadSize uint16, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (csum uint16, csumValid, ok bool) { + if offset := int(hdr.DataOffset()); offset < TCPMinimumSize || offset > len(hdr) { + return + } + + if skipChecksumValidation { + csumValid = true + } else { + csum = hdr.Checksum() + csumValid = hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum(), payloadSize) + } + return csum, csumValid, true +} diff --git a/internal/gtcpip/header/udp.go b/internal/gtcpip/header/udp.go new file mode 100644 index 0000000..080a97f --- /dev/null +++ b/internal/gtcpip/header/udp.go @@ -0,0 +1,195 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "math" + + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" +) + +const ( + udpSrcPort = 0 + udpDstPort = 2 + udpLength = 4 + udpChecksum = 6 +) + +const ( + // UDPMaximumPacketSize is the largest possible UDP packet. + UDPMaximumPacketSize = 0xffff +) + +// UDPFields contains the fields of a UDP packet. It is used to describe the +// fields of a packet that needs to be encoded. +type UDPFields struct { + // SrcPort is the "source port" field of a UDP packet. + SrcPort uint16 + + // DstPort is the "destination port" field of a UDP packet. + DstPort uint16 + + // Length is the "length" field of a UDP packet. + Length uint16 + + // Checksum is the "checksum" field of a UDP packet. + Checksum uint16 +} + +// UDP represents a UDP header stored in a byte array. +type UDP []byte + +const ( + // UDPMinimumSize is the minimum size of a valid UDP packet. + UDPMinimumSize = 8 + + // UDPMaximumSize is the maximum size of a valid UDP packet. The length field + // in the UDP header is 16 bits as per RFC 768. + UDPMaximumSize = math.MaxUint16 + + // UDPProtocolNumber is UDP's transport protocol number. + UDPProtocolNumber tcpip.TransportProtocolNumber = 17 +) + +// SourcePort returns the "source port" field of the UDP header. +func (b UDP) SourcePort() uint16 { + return binary.BigEndian.Uint16(b[udpSrcPort:]) +} + +// DestinationPort returns the "destination port" field of the UDP header. +func (b UDP) DestinationPort() uint16 { + return binary.BigEndian.Uint16(b[udpDstPort:]) +} + +// Length returns the "length" field of the UDP header. +func (b UDP) Length() uint16 { + return binary.BigEndian.Uint16(b[udpLength:]) +} + +// Payload returns the data contained in the UDP datagram. +func (b UDP) Payload() []byte { + return b[UDPMinimumSize:] +} + +// Checksum returns the "checksum" field of the UDP header. +func (b UDP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[udpChecksum:]) +} + +// SetSourcePort sets the "source port" field of the UDP header. +func (b UDP) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(b[udpSrcPort:], port) +} + +// SetDestinationPort sets the "destination port" field of the UDP header. +func (b UDP) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(b[udpDstPort:], port) +} + +// SetChecksum sets the "checksum" field of the UDP header. +func (b UDP) SetChecksum(xsum uint16) { + checksum.Put(b[udpChecksum:], xsum) +} + +// SetLength sets the "length" field of the UDP header. +func (b UDP) SetLength(length uint16) { + binary.BigEndian.PutUint16(b[udpLength:], length) +} + +// CalculateChecksum calculates the checksum of the UDP packet, given the +// checksum of the network-layer pseudo-header and the checksum of the payload. +func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { + // Calculate the rest of the checksum. + return checksum.Checksum(b[:UDPMinimumSize], partialChecksum) +} + +// IsChecksumValid returns true iff the UDP header's checksum is valid. +func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool { + xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length()) + xsum = checksum.Combine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + +// Encode encodes all the fields of the UDP header. +func (b UDP) Encode(u *UDPFields) { + b.SetSourcePort(u.SrcPort) + b.SetDestinationPort(u.DstPort) + b.SetLength(u.Length) + b.SetChecksum(u.Checksum) +} + +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} + +// UDPValid returns true if the pkt has a valid UDP header. It checks whether: +// - The length field is too small. +// - The length field is too large. +// - The checksum is invalid. +// +// UDPValid corresponds to net/netfilter/nf_conntrack_proto_udp.c:udp_error. +func UDPValid(hdr UDP, payloadChecksum func() uint16, payloadSize uint16, netProto tcpip.NetworkProtocolNumber, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (lengthValid, csumValid bool) { + if length := hdr.Length(); length > payloadSize+UDPMinimumSize || length < UDPMinimumSize { + return false, false + } + + if skipChecksumValidation { + return true, true + } + + // On IPv4, UDP checksum is optional, and a zero value means the transmitter + // omitted the checksum generation, as per RFC 768: + // + // An all zero transmitted checksum value means that the transmitter + // generated no checksum (for debugging or for higher level protocols that + // don't care). + // + // On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1: + // + // Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP + // checksum is not optional. + if netProto == IPv4ProtocolNumber && hdr.Checksum() == 0 { + return true, true + } + + return true, hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum()) +} diff --git a/internal/gtcpip/seqnum/seqnum.go b/internal/gtcpip/seqnum/seqnum.go new file mode 100644 index 0000000..d3bea7d --- /dev/null +++ b/internal/gtcpip/seqnum/seqnum.go @@ -0,0 +1,62 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package seqnum defines the types and methods for TCP sequence numbers such +// that they fit in 32-bit words and work properly when overflows occur. +package seqnum + +// Value represents the value of a sequence number. +type Value uint32 + +// Size represents the size (length) of a sequence number window. +type Size uint32 + +// LessThan checks if v is before w, i.e., v < w. +func (v Value) LessThan(w Value) bool { + return int32(v-w) < 0 +} + +// LessThanEq returns true if v==w or v is before i.e., v < w. +func (v Value) LessThanEq(w Value) bool { + if v == w { + return true + } + return v.LessThan(w) +} + +// InRange checks if v is in the range [a,b), i.e., a <= v < b. +func (v Value) InRange(a, b Value) bool { + return v-a < b-a +} + +// InWindow checks if v is in the window that starts at 'first' and spans 'size' +// sequence numbers. +func (v Value) InWindow(first Value, size Size) bool { + return v.InRange(first, first.Add(size)) +} + +// Add calculates the sequence number following the [v, v+s) window. +func (v Value) Add(s Size) Value { + return v + Value(s) +} + +// Size calculates the size of the window defined by [v, w). +func (v Value) Size(w Value) Size { + return Size(w - v) +} + +// UpdateForward updates v such that it becomes v + s. +func (v *Value) UpdateForward(s Size) { + *v += Value(s) +} diff --git a/internal/gtcpip/tcpip.go b/internal/gtcpip/tcpip.go new file mode 100644 index 0000000..60d2892 --- /dev/null +++ b/internal/gtcpip/tcpip.go @@ -0,0 +1,573 @@ +// Copyright 2024 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "crypto/rand" + "errors" + "fmt" + "math" + "math/bits" + "net" + "strconv" + "strings" + "time" +) + +// Using the header package here would cause an import cycle. +const ( + ipv4AddressSize = 4 + ipv4ProtocolNumber = 0x0800 + ipv6AddressSize = 16 + ipv6ProtocolNumber = 0x86dd +) + +const ( + // LinkAddressSize is the size of a MAC address. + LinkAddressSize = 6 +) + +// Known IP address. +var ( + IPv4Zero = []byte{0, 0, 0, 0} + IPv6Zero = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} +) + +// Errors related to Subnet +var ( + errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") + errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask") +) + +// TransportProtocolNumber is the number of a transport protocol. +type TransportProtocolNumber uint32 + +// NetworkProtocolNumber is the EtherType of a network protocol in an Ethernet +// frame. +// +// See: https://www.iana.org/assignments/ieee-802-numbers/ieee-802-numbers.xhtml +type NetworkProtocolNumber uint32 + +// MonotonicTime is a monotonic clock reading. +// +// +stateify savable +type MonotonicTime struct { + nanoseconds int64 +} + +// String implements Stringer. +func (mt MonotonicTime) String() string { + return strconv.FormatInt(mt.nanoseconds, 10) +} + +// MonotonicTimeInfinite returns the monotonic timestamp as far away in the +// future as possible. +func MonotonicTimeInfinite() MonotonicTime { + return MonotonicTime{nanoseconds: math.MaxInt64} +} + +// Before reports whether the monotonic clock reading mt is before u. +func (mt MonotonicTime) Before(u MonotonicTime) bool { + return mt.nanoseconds < u.nanoseconds +} + +// After reports whether the monotonic clock reading mt is after u. +func (mt MonotonicTime) After(u MonotonicTime) bool { + return mt.nanoseconds > u.nanoseconds +} + +// Add returns the monotonic clock reading mt+d. +func (mt MonotonicTime) Add(d time.Duration) MonotonicTime { + return MonotonicTime{ + nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(), + } +} + +// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum) +// value that can be stored in a Duration, the maximum (or minimum) duration +// will be returned. To compute t-d for a duration d, use t.Add(-d). +func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration { + return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds)) +} + +// Milliseconds returns the time in milliseconds. +func (mt MonotonicTime) Milliseconds() int64 { + return mt.nanoseconds / 1e6 +} + +// A Clock provides the current time and schedules work for execution. +// +// Times returned by a Clock should always be used for application-visible +// time. Only monotonic times should be used for netstack internal timekeeping. +type Clock interface { + // Now returns the current local time. + Now() time.Time + + // NowMonotonic returns the current monotonic clock reading. + NowMonotonic() MonotonicTime + + // AfterFunc waits for the duration to elapse and then calls f in its own + // goroutine. It returns a Timer that can be used to cancel the call using + // its Stop method. + AfterFunc(d time.Duration, f func()) Timer +} + +// Timer represents a single event. A Timer must be created with +// Clock.AfterFunc. +type Timer interface { + // Stop prevents the Timer from firing. It returns true if the call stops the + // timer, false if the timer has already expired or been stopped. + // + // If Stop returns false, then the timer has already expired and the function + // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop + // does not wait for f to complete before returning. If the caller needs to + // know whether f is completed, it must coordinate with f explicitly. + Stop() bool + + // Reset changes the timer to expire after duration d. + // + // Reset should be invoked only on stopped or expired timers. If the timer is + // known to have expired, Reset can be used directly. Otherwise, the caller + // must coordinate with the function f of Clock.AfterFunc(d, f). + Reset(d time.Duration) +} + +// Address is a byte slice cast as a string that represents the address of a +// network node. Or, in the case of unix endpoints, it may represent a path. +// +// +stateify savable +type Address struct { + addr [16]byte + length int +} + +// AddrFrom4 converts addr to an Address. +func AddrFrom4(addr [4]byte) Address { + ret := Address{ + length: 4, + } + // It's guaranteed that copy will return 4. + copy(ret.addr[:], addr[:]) + return ret +} + +// AddrFrom4Slice converts addr to an Address. It panics if len(addr) != 4. +func AddrFrom4Slice(addr []byte) Address { + if len(addr) != 4 { + panic(fmt.Sprintf("bad address length for address %v", addr)) + } + ret := Address{ + length: 4, + } + // It's guaranteed that copy will return 4. + copy(ret.addr[:], addr) + return ret +} + +// AddrFrom16 converts addr to an Address. +func AddrFrom16(addr [16]byte) Address { + ret := Address{ + length: 16, + } + // It's guaranteed that copy will return 16. + copy(ret.addr[:], addr[:]) + return ret +} + +// AddrFrom16Slice converts addr to an Address. It panics if len(addr) != 16. +func AddrFrom16Slice(addr []byte) Address { + if len(addr) != 16 { + panic(fmt.Sprintf("bad address length for address %v", addr)) + } + ret := Address{ + length: 16, + } + // It's guaranteed that copy will return 16. + copy(ret.addr[:], addr) + return ret +} + +// AddrFromSlice converts addr to an Address. It returns the Address zero value +// if len(addr) != 4 or 16. +func AddrFromSlice(addr []byte) Address { + switch len(addr) { + case ipv4AddressSize: + return AddrFrom4Slice(addr) + case ipv6AddressSize: + return AddrFrom16Slice(addr) + } + return Address{} +} + +// As4 returns a as a 4 byte array. It panics if the address length is not 4. +func (a Address) As4() [4]byte { + if a.Len() != 4 { + panic(fmt.Sprintf("bad address length for address %v", a.addr)) + } + return [4]byte(a.addr[:4]) +} + +// As16 returns a as a 16 byte array. It panics if the address length is not 16. +func (a Address) As16() [16]byte { + if a.Len() != 16 { + panic(fmt.Sprintf("bad address length for address %v", a.addr)) + } + return [16]byte(a.addr[:16]) +} + +// AsSlice returns a as a byte slice. Callers should be careful as it can +// return a window into existing memory. +// +// +checkescape +func (a *Address) AsSlice() []byte { + return a.addr[:a.length] +} + +// BitLen returns the length in bits of a. +func (a Address) BitLen() int { + return a.Len() * 8 +} + +// Len returns the length in bytes of a. +func (a Address) Len() int { + return a.length +} + +// WithPrefix returns the address with a prefix that represents a point subnet. +func (a Address) WithPrefix() AddressWithPrefix { + return AddressWithPrefix{ + Address: a, + PrefixLen: a.BitLen(), + } +} + +// Unspecified returns true if the address is unspecified. +func (a Address) Unspecified() bool { + for _, b := range a.addr { + if b != 0 { + return false + } + } + return true +} + +// Equal returns whether a and other are equal. It exists for use by the cmp +// library. +func (a Address) Equal(other Address) bool { + return a == other +} + +// MatchingPrefix returns the matching prefix length in bits. +// +// Panics if b and a have different lengths. +func (a Address) MatchingPrefix(b Address) uint8 { + const bitsInAByte = 8 + + if a.Len() != b.Len() { + panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b)) + } + + var prefix uint8 + for i := 0; i < a.length; i++ { + aByte := a.addr[i] + bByte := b.addr[i] + + if aByte == bByte { + prefix += bitsInAByte + continue + } + + // Count the remaining matching bits in the byte from MSbit to LSBbit. + mask := uint8(1) << (bitsInAByte - 1) + for { + if aByte&mask == bByte&mask { + prefix++ + mask >>= 1 + continue + } + + break + } + + break + } + + return prefix +} + +// AddressMask is a bitmask for an address. +// +// +stateify savable +type AddressMask struct { + mask [16]byte + length int +} + +// MaskFrom returns a Mask based on str. +// +// MaskFrom may allocate, and so should not be in hot paths. +func MaskFrom(str string) AddressMask { + mask := AddressMask{length: len(str)} + copy(mask.mask[:], str) + return mask +} + +// MaskFromBytes returns a Mask based on bs. +func MaskFromBytes(bs []byte) AddressMask { + mask := AddressMask{length: len(bs)} + copy(mask.mask[:], bs) + return mask +} + +// String implements Stringer. +func (m AddressMask) String() string { + return fmt.Sprintf("%x", m.mask) +} + +// AsSlice returns a as a byte slice. Callers should be careful as it can +// return a window into existing memory. +func (m *AddressMask) AsSlice() []byte { + return []byte(m.mask[:m.length]) +} + +// BitLen returns the length of the mask in bits. +func (m AddressMask) BitLen() int { + return m.length * 8 +} + +// Len returns the length of the mask in bytes. +func (m AddressMask) Len() int { + return m.length +} + +// Prefix returns the number of bits before the first host bit. +func (m AddressMask) Prefix() int { + p := 0 + for _, b := range m.mask[:m.length] { + p += bits.LeadingZeros8(^b) + } + return p +} + +// Equal returns whether m and other are equal. It exists for use by the cmp +// library. +func (m AddressMask) Equal(other AddressMask) bool { + return m == other +} + +// Subnet is a subnet defined by its address and mask. +// +// +stateify savable +type Subnet struct { + address Address + mask AddressMask +} + +// NewSubnet creates a new Subnet, checking that the address and mask are the same length. +func NewSubnet(a Address, m AddressMask) (Subnet, error) { + if a.Len() != m.Len() { + return Subnet{}, errSubnetLengthMismatch + } + for i := 0; i < a.Len(); i++ { + if a.addr[i]&^m.mask[i] != 0 { + return Subnet{}, errSubnetAddressMasked + } + } + return Subnet{a, m}, nil +} + +// String implements Stringer. +func (s Subnet) String() string { + return fmt.Sprintf("%s/%d", s.ID(), s.Prefix()) +} + +// Contains returns true iff the address is of the same length and matches the +// subnet address and mask. +func (s *Subnet) Contains(a Address) bool { + if a.Len() != s.address.Len() { + return false + } + for i := 0; i < a.Len(); i++ { + if a.addr[i]&s.mask.mask[i] != s.address.addr[i] { + return false + } + } + return true +} + +// ID returns the subnet ID. +func (s *Subnet) ID() Address { + return s.address +} + +// Bits returns the number of ones (network bits) and zeros (host bits) in the +// subnet mask. +func (s *Subnet) Bits() (ones int, zeros int) { + ones = s.mask.Prefix() + return ones, s.mask.BitLen() - ones +} + +// Prefix returns the number of bits before the first host bit. +func (s *Subnet) Prefix() int { + return s.mask.Prefix() +} + +// Mask returns the subnet mask. +func (s *Subnet) Mask() AddressMask { + return s.mask +} + +// Broadcast returns the subnet's broadcast address. +func (s *Subnet) Broadcast() Address { + addrCopy := s.address + for i := 0; i < addrCopy.Len(); i++ { + addrCopy.addr[i] |= ^s.mask.mask[i] + } + return addrCopy +} + +// IsBroadcast returns true if the address is considered a broadcast address. +func (s *Subnet) IsBroadcast(address Address) bool { + // Only IPv4 supports the notion of a broadcast address. + if address.Len() != ipv4AddressSize { + return false + } + + // Normally, we would just compare address with the subnet's broadcast + // address but there is an exception where a simple comparison is not + // correct. This exception is for /31 and /32 IPv4 subnets where all + // addresses are considered valid host addresses. + // + // For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that + // both addresses in a /31 subnet "MUST be interpreted as host addresses." + // + // For /32, the case is a bit more vague. RFC 3021 makes no mention of /32 + // subnets. However, the same reasoning applies - if an exception is not + // made, then there do not exist any host addresses in a /32 subnet. RFC + // 4632 Section 3.1 also vaguely implies this interpretation by referring + // to addresses in /32 subnets as "host routes." + return s.Prefix() <= 30 && s.Broadcast() == address +} + +// Equal returns true if this Subnet is equal to the given Subnet. +func (s Subnet) Equal(o Subnet) bool { + // If this changes, update Route.Equal accordingly. + return s == o +} + +// LinkAddress is a byte slice cast as a string that represents a link address. +// It is typically a 6-byte MAC address. +type LinkAddress string + +// String implements the fmt.Stringer interface. +func (a LinkAddress) String() string { + switch len(a) { + case 6: + return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5]) + default: + return fmt.Sprintf("%x", []byte(a)) + } +} + +// ParseMACAddress parses an IEEE 802 address. +// +// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. +func ParseMACAddress(s string) (LinkAddress, error) { + parts := strings.FieldsFunc(s, func(c rune) bool { + return c == ':' || c == '-' + }) + if len(parts) != LinkAddressSize { + return "", fmt.Errorf("inconsistent parts: %s", s) + } + addr := make([]byte, 0, len(parts)) + for _, part := range parts { + u, err := strconv.ParseUint(part, 16, 8) + if err != nil { + return "", fmt.Errorf("invalid hex digits: %s", s) + } + addr = append(addr, byte(u)) + } + return LinkAddress(addr), nil +} + +// GetRandMacAddr returns a mac address that can be used for local virtual devices. +func GetRandMacAddr() LinkAddress { + mac := make(net.HardwareAddr, LinkAddressSize) + rand.Read(mac) // Fill with random data. + mac[0] &^= 0x1 // Clear multicast bit. + mac[0] |= 0x2 // Set local assignment bit (IEEE802). + return LinkAddress(mac) +} + +// AddressWithPrefix is an address with its subnet prefix length. +// +// +stateify savable +type AddressWithPrefix struct { + // Address is a network address. + Address Address + + // PrefixLen is the subnet prefix length. + PrefixLen int +} + +// String implements the fmt.Stringer interface. +func (a AddressWithPrefix) String() string { + return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) +} + +// Subnet converts the address and prefix into a Subnet value and returns it. +func (a AddressWithPrefix) Subnet() Subnet { + addrLen := a.Address.length + if a.PrefixLen <= 0 { + return Subnet{ + address: Address{length: addrLen}, + mask: AddressMask{length: addrLen}, + } + } + if a.PrefixLen >= addrLen*8 { + sub := Subnet{ + address: a.Address, + mask: AddressMask{length: addrLen}, + } + for i := 0; i < addrLen; i++ { + sub.mask.mask[i] = 0xff + } + return sub + } + + sa := Address{length: addrLen} + sm := AddressMask{length: addrLen} + n := uint(a.PrefixLen) + for i := 0; i < addrLen; i++ { + if n >= 8 { + sa.addr[i] = a.Address.addr[i] + sm.mask[i] = 0xff + n -= 8 + continue + } + sm.mask[i] = ^byte(0xff >> n) + sa.addr[i] = a.Address.addr[i] & sm.mask[i] + n = 0 + } + + // For extra caution, call NewSubnet rather than directly creating the Subnet + // value. If that fails it indicates a serious bug in this code, so panic is + // in order. + s, err := NewSubnet(sa, sm) + if err != nil { + panic("invalid subnet: " + err.Error()) + } + return s +} diff --git a/network_name.go b/network_name.go index c136332..fa487fb 100644 --- a/network_name.go +++ b/network_name.go @@ -3,20 +3,21 @@ package tun import ( "strconv" - "github.com/sagernet/sing-tun/internal/clashtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/header" F "github.com/sagernet/sing/common/format" N "github.com/sagernet/sing/common/network" ) func NetworkName(network uint8) string { - switch network { - case clashtcpip.TCP: + switch tcpip.TransportProtocolNumber(network) { + case header.TCPProtocolNumber: return N.NetworkTCP - case clashtcpip.UDP: + case header.UDPProtocolNumber: return N.NetworkUDP - case clashtcpip.ICMP: + case header.ICMPv4ProtocolNumber: return N.NetworkICMPv4 - case clashtcpip.ICMPv6: + case header.ICMPv6ProtocolNumber: return N.NetworkICMPv6 } return F.ToString(network) @@ -25,13 +26,13 @@ func NetworkName(network uint8) string { func NetworkFromName(name string) uint8 { switch name { case N.NetworkTCP: - return clashtcpip.TCP + return uint8(header.TCPProtocolNumber) case N.NetworkUDP: - return clashtcpip.UDP + return uint8(header.UDPProtocolNumber) case N.NetworkICMPv4: - return clashtcpip.ICMP + return uint8(header.ICMPv4ProtocolNumber) case N.NetworkICMPv6: - return clashtcpip.ICMPv6 + return uint8(header.ICMPv6ProtocolNumber) } parseNetwork, err := strconv.ParseUint(name, 10, 8) if err != nil { diff --git a/stack_gvisor.go b/stack_gvisor.go index 0ee5b5b..523360b 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -79,6 +79,11 @@ func (t *GVisor) Start() error { tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) + pErr := t.handler.PrepareConnection(source, destination) + if pErr != nil { + r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid) + return + } conn := &gLazyConn{ parentCtx: t.ctx, stack: t.stack, @@ -86,11 +91,6 @@ func (t *GVisor) Start() error { localAddr: source.TCPAddr(), remoteAddr: destination.TCPAddr(), } - pErr := t.handler.PrepareConnection(source, destination) - if pErr != nil { - r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid) - return - } go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) diff --git a/stack_mixed.go b/stack_mixed.go index 8e1ab8a..d4e0607 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -6,13 +6,14 @@ import ( "time" "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" - "github.com/sagernet/gvisor/pkg/tcpip/header" + gHdr "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/link/channel" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/gvisor/pkg/waiter" - "github.com/sagernet/sing-tun/internal/clashtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" @@ -104,7 +105,7 @@ func (m *Mixed) tunLoop() { } m.logger.Error(E.Cause(err, "read packet")) } - if n < clashtcpip.IPv4PacketMinLength { + if n < header.IPv4MinimumSize { continue } rawPacket := packetBuffer[:n] @@ -124,7 +125,7 @@ func (m *Mixed) wintunLoop(winTun WinTun) { if err != nil { return } - if len(packet) < clashtcpip.IPv4PacketMinLength { + if len(packet) < header.IPv4MinimumSize { release() continue } @@ -158,7 +159,7 @@ func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) { } for i := 0; i < n; i++ { packetSize := packetSizes[i] - if packetSize < clashtcpip.IPv4PacketMinLength { + if packetSize < header.IPv4MinimumSize { continue } packetBuffer := packetBuffers[i] @@ -197,48 +198,48 @@ func (m *Mixed) processPacket(packet []byte) bool { return writeBack } -func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { +func (m *Mixed) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) { writeBack = true - destination := packet.DestinationIP() + destination := ipHdr.DestinationAddr() if destination == m.broadcastAddr || !destination.IsGlobalUnicast() { return } - switch packet.Protocol() { - case clashtcpip.TCP: - err = m.processIPv4TCP(packet, packet.Payload()) - case clashtcpip.UDP: + switch ipHdr.TransportProtocol() { + case header.TCPProtocolNumber: + writeBack, err = m.processIPv4TCP(ipHdr, ipHdr.Payload()) + case header.UDPProtocolNumber: writeBack = false pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packet), + Payload: buffer.MakeWithData(ipHdr), IsForwardedPacket: true, }) - m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) + m.endpoint.InjectInbound(gHdr.IPv4ProtocolNumber, pkt) pkt.DecRef() return - case clashtcpip.ICMP: - err = m.processIPv4ICMP(packet, packet.Payload()) + case header.ICMPv4ProtocolNumber: + err = m.processIPv4ICMP(ipHdr, ipHdr.Payload()) } return } -func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { +func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) { writeBack = true - if !packet.DestinationIP().IsGlobalUnicast() { + if !ipHdr.DestinationAddr().IsGlobalUnicast() { return } - switch packet.Protocol() { - case clashtcpip.TCP: - err = m.processIPv6TCP(packet, packet.Payload()) - case clashtcpip.UDP: + switch ipHdr.TransportProtocol() { + case header.TCPProtocolNumber: + err = m.processIPv6TCP(ipHdr, ipHdr.Payload()) + case header.UDPProtocolNumber: writeBack = false pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packet), + Payload: buffer.MakeWithData(ipHdr), IsForwardedPacket: true, }) - m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) + m.endpoint.InjectInbound(tcpip.NetworkProtocolNumber(header.IPv6ProtocolNumber), pkt) pkt.DecRef() - case clashtcpip.ICMPv6: - err = m.processIPv6ICMP(packet, packet.Payload()) + case header.ICMPv6ProtocolNumber: + err = m.processIPv6ICMP(ipHdr, ipHdr.Payload()) } return } diff --git a/stack_system.go b/stack_system.go index 57e908b..5fcc2bf 100644 --- a/stack_system.go +++ b/stack_system.go @@ -7,7 +7,8 @@ import ( "syscall" "time" - "github.com/sagernet/sing-tun/internal/clashtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" @@ -179,7 +180,7 @@ func (s *System) tunLoop() { } s.logger.Error(E.Cause(err, "read packet")) } - if n < clashtcpip.IPv4PacketMinLength { + if n < header.IPv4MinimumSize { continue } rawPacket := packetBuffer[:n] @@ -199,7 +200,7 @@ func (s *System) wintunLoop(winTun WinTun) { if err != nil { return } - if len(packet) < clashtcpip.IPv4PacketMinLength { + if len(packet) < header.IPv4MinimumSize { release() continue } @@ -233,7 +234,7 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { } for i := 0; i < n; i++ { packetSize := packetSizes[i] - if packetSize < clashtcpip.IPv4PacketMinLength { + if packetSize < header.IPv4MinimumSize { continue } packetBuffer := packetBuffers[i] @@ -300,83 +301,89 @@ func (s *System) acceptLoop(listener net.Listener) { } } } - go s.handler.NewConnectionEx(s.ctx, conn, M.SocksaddrFromNet(conn.RemoteAddr()), destination, nil) + go s.handler.NewConnectionEx(s.ctx, conn, M.SocksaddrFromNetIP(session.Source), destination, nil) } } -func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { +func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) { writeBack = true - destination := packet.DestinationIP() + destination := ipHdr.DestinationAddr() if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { return } - switch packet.Protocol() { - case clashtcpip.TCP: - err = s.processIPv4TCP(packet, packet.Payload()) - case clashtcpip.UDP: + switch ipHdr.TransportProtocol() { + case header.TCPProtocolNumber: + writeBack, err = s.processIPv4TCP(ipHdr, ipHdr.Payload()) + case header.UDPProtocolNumber: writeBack = false - err = s.processIPv4UDP(packet, packet.Payload()) - case clashtcpip.ICMP: - err = s.processIPv4ICMP(packet, packet.Payload()) + err = s.processIPv4UDP(ipHdr, ipHdr.Payload()) + case header.ICMPv4ProtocolNumber: + err = s.processIPv4ICMP(ipHdr, ipHdr.Payload()) } return } -func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { +func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) { writeBack = true - if !packet.DestinationIP().IsGlobalUnicast() { + if !ipHdr.DestinationAddr().IsGlobalUnicast() { return } - switch packet.Protocol() { - case clashtcpip.TCP: - err = s.processIPv6TCP(packet, packet.Payload()) - case clashtcpip.UDP: - writeBack = false - err = s.processIPv6UDP(packet, packet.Payload()) - case clashtcpip.ICMPv6: - err = s.processIPv6ICMP(packet, packet.Payload()) + switch ipHdr.TransportProtocol() { + case header.TCPProtocolNumber: + err = s.processIPv6TCP(ipHdr, ipHdr.Payload()) + case header.UDPProtocolNumber: + err = s.processIPv6UDP(ipHdr, ipHdr.Payload()) + case header.ICMPv6ProtocolNumber: + err = s.processIPv6ICMP(ipHdr, ipHdr.Payload()) } return } -func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { - source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) - destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) +func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, error) { + source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort()) + destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return nil + return true, nil } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { - return E.New("ipv4: tcp: session not found: ", destination.Port()) + return false, E.New("ipv4: tcp: session not found: ", destination.Port()) } - packet.SetSourceIP(session.Destination.Addr()) - header.SetSourcePort(session.Destination.Port()) - packet.SetDestinationIP(session.Source.Addr()) - header.SetDestinationPort(session.Source.Port()) + ipHdr.SetSourceAddr(session.Destination.Addr()) + tcpHdr.SetSourcePort(session.Destination.Port()) + ipHdr.SetDestinationAddr(session.Source.Addr()) + tcpHdr.SetDestinationPort(session.Source.Port()) } else { natPort, err := s.tcpNat.Lookup(source, destination, s.handler) if err != nil { - // TODO: implement rejects - return nil + // TODO: implement ICMP port unreachable + return false, nil } - packet.SetSourceIP(s.inet4Address) - header.SetSourcePort(natPort) - packet.SetDestinationIP(s.inet4ServerAddress) - header.SetDestinationPort(s.tcpPort) + ipHdr.SetSourceAddr(s.inet4Address) + tcpHdr.SetSourcePort(natPort) + ipHdr.SetDestinationAddr(s.inet4ServerAddress) + tcpHdr.SetDestinationPort(s.tcpPort) } if !s.txChecksumOffload { - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() + tcpHdr.SetChecksum(0) + tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( + header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), + ))) } else { - header.OffloadChecksum() - packet.ResetChecksum() + tcpHdr.SetChecksum(0) } + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + return true, nil +} + +func (s *System) resetIPv4TCP(packet header.IPv4, header header.TCP) error { return nil } -func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { - source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) - destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) +func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) error { + source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort()) + destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { return nil } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { @@ -384,58 +391,55 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip. if session == nil { return E.New("ipv6: tcp: session not found: ", destination.Port()) } - packet.SetSourceIP(session.Destination.Addr()) - header.SetSourcePort(session.Destination.Port()) - packet.SetDestinationIP(session.Source.Addr()) - header.SetDestinationPort(session.Source.Port()) + ipHdr.SetSourceAddr(session.Destination.Addr()) + tcpHdr.SetSourcePort(session.Destination.Port()) + ipHdr.SetSourceAddr(session.Source.Addr()) + tcpHdr.SetDestinationPort(session.Source.Port()) } else { natPort, err := s.tcpNat.Lookup(source, destination, s.handler) if err != nil { - // TODO: implement rejects + // TODO: implement ICMP port unreachable return nil } - packet.SetSourceIP(s.inet6Address) - header.SetSourcePort(natPort) - packet.SetDestinationIP(s.inet6ServerAddress) - header.SetDestinationPort(s.tcpPort6) + ipHdr.SetSourceAddr(s.inet6Address) + tcpHdr.SetSourcePort(natPort) + ipHdr.SetSourceAddr(s.inet6ServerAddress) + tcpHdr.SetDestinationPort(s.tcpPort6) } if !s.txChecksumOffload { - header.ResetChecksum(packet.PseudoSum()) + tcpHdr.SetChecksum(0) + tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( + header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), + ))) } else { - header.OffloadChecksum() + tcpHdr.SetChecksum(0) } return nil } -func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { - if packet.Flags()&clashtcpip.FlagMoreFragment != 0 { +func (s *System) processIPv4UDP(ipHdr header.IPv4, udpHdr header.UDP) error { + if ipHdr.Flags()&header.IPv4FlagMoreFragments != 0 { return E.New("ipv4: fragment dropped") } - if packet.FragmentOffset() != 0 { + if ipHdr.FragmentOffset() != 0 { return E.New("ipv4: udp: fragment dropped") } - if !header.Valid() { - return E.New("ipv4: udp: invalid packet") - } - source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort()) - destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort()) + source := M.SocksaddrFrom(ipHdr.SourceAddr(), udpHdr.SourcePort()) + destination := M.SocksaddrFrom(ipHdr.DestinationAddr(), udpHdr.DestinationPort()) if !destination.Addr.IsGlobalUnicast() { return nil } - s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet) + s.udpNat.NewPacket([][]byte{udpHdr.Payload()}, source, destination, ipHdr) return nil } -func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { - if !header.Valid() { - return E.New("ipv6: udp: invalid packet") - } - source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort()) - destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort()) +func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error { + source := M.SocksaddrFrom(ipHdr.SourceAddr(), udpHdr.SourcePort()) + destination := M.SocksaddrFrom(ipHdr.DestinationAddr(), udpHdr.DestinationPort()) if !destination.Addr.IsGlobalUnicast() { return nil } - s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet) + s.udpNat.NewPacket([][]byte{udpHdr.Payload()}, source, destination, ipHdr) return nil } @@ -447,8 +451,8 @@ func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socks } var writer N.PacketWriter if source.IsIPv4() { - packet := userData.(clashtcpip.IPv4Packet) - headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize + packet := userData.(header.IPv4) + headerLen := packet.HeaderLength() + header.UDPMinimumSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) writer = &systemUDPPacketWriter4{ @@ -459,8 +463,8 @@ func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socks s.txChecksumOffload, } } else { - packet := userData.(clashtcpip.IPv6Packet) - headerLen := len(packet) - int(packet.PayloadLength()) + clashtcpip.UDPHeaderSize + packet := userData.(header.IPv6) + headerLen := len(packet) - int(packet.PayloadLength()) + header.UDPMinimumSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) writer = &systemUDPPacketWriter6{ @@ -474,32 +478,87 @@ func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socks return true, s.ctx, writer, nil } -func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { - if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 { +func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error { + if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 { return nil } - header.SetType(clashtcpip.ICMPTypePingResponse) - sourceAddress := packet.SourceIP() - packet.SetSourceIP(packet.DestinationIP()) - packet.SetDestinationIP(sourceAddress) - header.ResetChecksum() - packet.ResetChecksum() + icmpHdr.SetType(header.ICMPv4EchoReply) + sourceAddress := ipHdr.SourceAddr() + ipHdr.SetSourceAddr(ipHdr.DestinationAddr()) + ipHdr.SetDestinationAddr(sourceAddress) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, checksum.Checksum(icmpHdr.Payload(), 0))) + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) return nil } -func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { - if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 { +func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error { + if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 { return nil } - header.SetType(clashtcpip.ICMPv6EchoReply) - sourceAddress := packet.SourceIP() - packet.SetSourceIP(packet.DestinationIP()) - packet.SetDestinationIP(sourceAddress) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() + icmpHdr.SetType(header.ICMPv6EchoReply) + sourceAddress := ipHdr.SourceAddr() + ipHdr.SetSourceAddr(ipHdr.DestinationAddr()) + ipHdr.SetDestinationAddr(sourceAddress) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: ipHdr.SourceAddress(), + Dst: ipHdr.DestinationAddress(), + })) return nil } +/*func (s *System) WritePacket4(buffer *buf.Buffer, source netip.AddrPort, destination netip.AddrPort) error { + packet := buf.Get(header.IPv4MinimumSize + header.UDPMinimumSize + buffer.Len()) + ipHdr := header.IPv4(packet) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(packet)), + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: source.Addr(), + DstAddr: destination.Addr(), + }) + ipHdr.SetHeaderLength(header.IPv4MinimumSize) + udpHdr := header.UDP(ipHdr.Payload()) + udpHdr.Encode(&header.UDPFields{ + SrcPort: source.Port(), + DstPort: destination.Port(), + Length: uint16(header.UDPMinimumSize + buffer.Len()), + }) + copy(udpHdr.Payload(), buffer.Bytes()) + if !s.txChecksumOffload { + ... + } else { + udpHdr.SetChecksum(0) + } + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + return common.Error(s.tun.Write(packet)) +} + +func (s *System) WritePacket6(buffer *buf.Buffer, source netip.AddrPort, destination netip.AddrPort) error { + packet := buf.Get(header.IPv6MinimumSize + header.UDPMinimumSize + buffer.Len()) + ipHdr := header.IPv6(packet) + ipHdr.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + buffer.Len()), + TransportProtocol: header.UDPProtocolNumber, + SrcAddr: source.Addr(), + DstAddr: destination.Addr(), + }) + udpHdr := header.UDP(ipHdr.Payload()) + udpHdr.Encode(&header.UDPFields{ + SrcPort: source.Port(), + DstPort: destination.Port(), + Length: uint16(header.UDPMinimumSize + buffer.Len()), + }) + copy(udpHdr.Payload(), buffer.Bytes()) + if !s.txChecksumOffload { + ... + } else { + udpHdr.SetChecksum(0) + } + return common.Error(s.tun.Write(packet)) +}*/ + type systemUDPPacketWriter4 struct { tun Tun frontHeadroom int @@ -514,21 +573,24 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S newPacket.Resize(w.frontHeadroom, 0) newPacket.Write(w.header) newPacket.Write(buffer.Bytes()) - ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes()) + ipHdr := header.IPv4(newPacket.Bytes()) ipHdr.SetTotalLength(uint16(newPacket.Len())) - ipHdr.SetDestinationIP(ipHdr.SourceIP()) - ipHdr.SetSourceIP(destination.Addr) - udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) + ipHdr.SetSourceAddress(ipHdr.SourceAddress()) + ipHdr.SetSourceAddr(destination.Addr) + udpHdr := header.UDP(ipHdr.Payload()) udpHdr.SetDestinationPort(udpHdr.SourcePort()) udpHdr.SetSourcePort(destination.Port) - udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize)) + udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize)) if !w.txChecksumOffload { - udpHdr.ResetChecksum(ipHdr.PseudoSum()) - ipHdr.ResetChecksum() + udpHdr.SetChecksum(0) + udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum( + header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), + ))) } else { - udpHdr.OffloadChecksum() - ipHdr.ResetChecksum() + udpHdr.SetChecksum(0) } + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET } else { @@ -551,19 +613,22 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S newPacket.Resize(w.frontHeadroom, 0) newPacket.Write(w.header) newPacket.Write(buffer.Bytes()) - ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes()) - udpLen := uint16(clashtcpip.UDPHeaderSize + buffer.Len()) + ipHdr := header.IPv6(newPacket.Bytes()) + udpLen := uint16(header.UDPMinimumSize + buffer.Len()) ipHdr.SetPayloadLength(udpLen) - ipHdr.SetDestinationIP(ipHdr.SourceIP()) - ipHdr.SetSourceIP(destination.Addr) - udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) + ipHdr.SetDestinationAddress(ipHdr.SourceAddress()) + ipHdr.SetSourceAddr(destination.Addr) + udpHdr := header.UDP(ipHdr.Payload()) udpHdr.SetDestinationPort(udpHdr.SourcePort()) udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(udpLen) if !w.txChecksumOffload { - udpHdr.ResetChecksum(ipHdr.PseudoSum()) + udpHdr.SetChecksum(0) + udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum( + header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), + ))) } else { - udpHdr.OffloadChecksum() + udpHdr.SetChecksum(0) } if PacketOffset > 0 { newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 diff --git a/tun_linux_offload.go b/tun_linux_offload.go index 930b939..488d3ff 100644 --- a/tun_linux_offload.go +++ b/tun_linux_offload.go @@ -15,7 +15,7 @@ import ( "io" "unsafe" - "github.com/sagernet/sing-tun/internal/clashtcpip" + "github.com/sagernet/sing-tun/internal/gtcpip/checksum" E "github.com/sagernet/sing/common/exceptions" "golang.org/x/sys/unix" @@ -746,7 +746,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e } func checksumNoFold(b []byte, initial uint64) uint64 { - return initial + uint64(clashtcpip.Sum(b)) + return uint64(checksum.Checksum(b, uint16(initial))) } func checksumFold(b []byte, initial uint64) uint16 {