Implement TCP and ICMP rejects

This commit is contained in:
世界 2024-10-22 21:18:32 +08:00
parent 1793988a6d
commit e95737eccb
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 256 additions and 107 deletions

View file

@ -2,6 +2,7 @@ package tun
import (
"context"
"errors"
"net"
"net/netip"
"syscall"
@ -258,10 +259,10 @@ func (s *System) processPacket(packet []byte) bool {
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
writeBack, err = s.processIPv4(packet)
case 6:
case header.IPv6Version:
writeBack, err = s.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
@ -306,11 +307,11 @@ func (s *System) acceptLoop(listener net.Listener) {
}
func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
writeBack = true
destination := ipHdr.DestinationAddr()
if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
return
}
writeBack = true
switch ipHdr.TransportProtocol() {
case header.TCPProtocolNumber:
writeBack, err = s.processIPv4TCP(ipHdr, ipHdr.Payload())
@ -324,13 +325,13 @@ func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
}
func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
writeBack = true
if !ipHdr.DestinationAddr().IsGlobalUnicast() {
return
}
writeBack = true
switch ipHdr.TransportProtocol() {
case header.TCPProtocolNumber:
err = s.processIPv6TCP(ipHdr, ipHdr.Payload())
writeBack, err = s.processIPv6TCP(ipHdr, ipHdr.Payload())
case header.UDPProtocolNumber:
err = s.processIPv6UDP(ipHdr, ipHdr.Payload())
case header.ICMPv6ProtocolNumber:
@ -343,7 +344,7 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort())
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return true, nil
return false, nil
} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
@ -356,8 +357,17 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
} else {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
// TODO: implement ICMP port unreachable
return false, nil
if errors.Is(err, ErrDrop) {
return false, nil
} else if errors.Is(err, syscall.ENETUNREACH) {
return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4NetUnreachable)
} else if errors.Is(err, syscall.EHOSTUNREACH) {
return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable)
} else if errors.Is(err, syscall.ECONNREFUSED) {
return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)
} else {
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet4Address)
tcpHdr.SetSourcePort(natPort)
@ -377,33 +387,84 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
return true, nil
}
func (s *System) resetIPv4TCP(packet header.IPv4, header header.TCP) error {
return nil
func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) error {
frontHeadroom := s.frontHeadroom + PacketOffset
newPacket := buf.NewSize(frontHeadroom + header.IPv4MinimumSize + header.TCPMinimumSize)
defer newPacket.Release()
newPacket.Resize(frontHeadroom, header.IPv4MinimumSize+header.TCPMinimumSize)
ipHdr := header.IPv4(newPacket.Bytes())
ipHdr.Encode(&header.IPv4Fields{
TotalLength: uint16(newPacket.Len()),
Protocol: uint8(header.TCPProtocolNumber),
SrcAddr: origIPHdr.DestinationAddr(),
DstAddr: origIPHdr.SourceAddr(),
})
tcpHdr := header.TCP(ipHdr.Payload())
fields := header.TCPFields{
SrcPort: origTCPHdr.DestinationPort(),
DstPort: origTCPHdr.SourcePort(),
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagRst,
}
if origTCPHdr.Flags()&header.TCPFlagAck != 0 {
fields.SeqNum = origTCPHdr.AckNumber()
} else {
fields.Flags |= header.TCPFlagAck
ackNum := origTCPHdr.SequenceNumber() + uint32(len(origTCPHdr.Payload()))
if origTCPHdr.Flags()&header.TCPFlagSyn != 0 {
ackNum++
}
if origTCPHdr.Flags()&header.TCPFlagFin != 0 {
ackNum++
}
fields.AckNum = ackNum
}
tcpHdr.Encode(&fields)
if !s.txChecksumOffload {
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
} else {
newPacket.Advance(-s.frontHeadroom)
}
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) error {
func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, error) {
source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort())
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return nil
return false, nil
} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
return E.New("ipv6: tcp: session not found: ", destination.Port())
return false, E.New("ipv6: tcp: session not found: ", destination.Port())
}
ipHdr.SetSourceAddr(session.Destination.Addr())
tcpHdr.SetSourcePort(session.Destination.Port())
ipHdr.SetSourceAddr(session.Source.Addr())
ipHdr.SetDestinationAddr(session.Source.Addr())
tcpHdr.SetDestinationPort(session.Source.Port())
} else {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
// TODO: implement ICMP port unreachable
return nil
if errors.Is(err, ErrDrop) {
return false, nil
} else if errors.Is(err, syscall.ENETUNREACH) {
return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6NetworkUnreachable)
} else if errors.Is(err, syscall.EHOSTUNREACH) {
return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable)
} else if errors.Is(err, syscall.ECONNREFUSED) {
return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable)
} else {
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet6Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetSourceAddr(s.inet6ServerAddress)
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort6)
}
if !s.txChecksumOffload {
@ -414,7 +475,51 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) error {
} else {
tcpHdr.SetChecksum(0)
}
return nil
return true, nil
}
func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) error {
frontHeadroom := s.frontHeadroom + PacketOffset
newPacket := buf.NewSize(frontHeadroom + header.IPv6MinimumSize + header.TCPMinimumSize)
defer newPacket.Release()
newPacket.Resize(frontHeadroom, header.IPv6MinimumSize+header.TCPMinimumSize)
ipHdr := header.IPv6(newPacket.Bytes())
ipHdr.Encode(&header.IPv6Fields{
PayloadLength: uint16(header.TCPMinimumSize),
TransportProtocol: header.TCPProtocolNumber,
SrcAddr: origIPHdr.DestinationAddr(),
DstAddr: origIPHdr.SourceAddr(),
})
tcpHdr := header.TCP(ipHdr.Payload())
fields := header.TCPFields{
SrcPort: origTCPHdr.DestinationPort(),
DstPort: origTCPHdr.SourcePort(),
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagRst,
}
if origTCPHdr.Flags()&header.TCPFlagAck != 0 {
fields.SeqNum = origTCPHdr.AckNumber()
} else {
fields.Flags |= header.TCPFlagAck
ackNum := origTCPHdr.SequenceNumber() + uint32(len(origTCPHdr.Payload()))
if origTCPHdr.Flags()&header.TCPFlagSyn != 0 {
ackNum++
}
if origTCPHdr.Flags()&header.TCPFlagFin != 0 {
ackNum++
}
fields.AckNum = ackNum
}
tcpHdr.Encode(&fields)
if !s.txChecksumOffload {
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
} else {
newPacket.Advance(-s.frontHeadroom)
}
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) processIPv4UDP(ipHdr header.IPv4, udpHdr header.UDP) error {
@ -444,9 +549,28 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
}
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := s.handler.PrepareConnection(source, destination)
pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
if pErr != nil {
// TODO: implement ICMP port unreachable
if errors.Is(pErr, ErrDrop) {
} else if source.IsIPv4() {
ipHdr := userData.(header.IPv4)
if errors.Is(pErr, syscall.ENETUNREACH) {
s.rejectIPv4WithICMP(ipHdr, header.ICMPv4NetUnreachable)
} else if errors.Is(pErr, syscall.EHOSTUNREACH) {
s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable)
} else {
s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)
}
} else {
ipHdr := userData.(header.IPv6)
if errors.Is(pErr, syscall.ENETUNREACH) {
s.rejectIPv6WithICMP(ipHdr, header.ICMPv6NetworkUnreachable)
} else if errors.Is(pErr, syscall.EHOSTUNREACH) {
s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable)
} else {
s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable)
}
}
return false, nil, nil, nil
}
var writer N.PacketWriter
@ -492,6 +616,45 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error
return nil
}
func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) error {
frontHeadroom := s.frontHeadroom + PacketOffset
mtu := s.mtu
const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize
if mtu > maxIPData {
mtu = maxIPData
}
available := mtu - header.ICMPv4MinimumSize
if available < len(ipHdr)+header.ICMPv4MinimumErrorPayloadSize {
return nil
}
payload := ipHdr
if len(payload) > available {
payload = payload[:available]
}
newPacket := buf.NewSize(frontHeadroom + header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(payload))
defer newPacket.Release()
newPacket.Resize(frontHeadroom, header.IPv4MinimumSize+header.ICMPv4MinimumSize+len(payload))
newIPHdr := header.IPv4(newPacket.Bytes())
newIPHdr.Encode(&header.IPv4Fields{
TotalLength: uint16(newPacket.Len()),
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: ipHdr.DestinationAddr(),
DstAddr: ipHdr.SourceAddr(),
})
newIPHdr.SetChecksum(^newIPHdr.CalculateChecksum())
icmpHdr := header.ICMPv4(newIPHdr.Payload())
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(code)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
} else {
newPacket.Advance(-s.frontHeadroom)
}
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error {
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
return nil
@ -508,57 +671,50 @@ func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error
return nil
}
/*func (s *System) WritePacket4(buffer *buf.Buffer, source netip.AddrPort, destination netip.AddrPort) error {
packet := buf.Get(header.IPv4MinimumSize + header.UDPMinimumSize + buffer.Len())
ipHdr := header.IPv4(packet)
ipHdr.Encode(&header.IPv4Fields{
TotalLength: uint16(len(packet)),
Protocol: uint8(header.UDPProtocolNumber),
SrcAddr: source.Addr(),
DstAddr: destination.Addr(),
})
ipHdr.SetHeaderLength(header.IPv4MinimumSize)
udpHdr := header.UDP(ipHdr.Payload())
udpHdr.Encode(&header.UDPFields{
SrcPort: source.Port(),
DstPort: destination.Port(),
Length: uint16(header.UDPMinimumSize + buffer.Len()),
})
copy(udpHdr.Payload(), buffer.Bytes())
if !s.txChecksumOffload {
...
} else {
udpHdr.SetChecksum(0)
func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) error {
frontHeadroom := s.frontHeadroom + PacketOffset
mtu := s.mtu
const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize
if mtu > maxIPv6Data {
mtu = maxIPv6Data
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return common.Error(s.tun.Write(packet))
available := mtu - header.ICMPv6ErrorHeaderSize
if available < header.IPv6MinimumSize {
return nil
}
payload := ipHdr
if len(payload) > available {
payload = payload[:available]
}
newPacket := buf.NewSize(frontHeadroom + header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + len(payload))
defer newPacket.Release()
newPacket.Resize(frontHeadroom, header.IPv6MinimumSize+header.ICMPv6DstUnreachableMinimumSize+len(payload))
newIPHdr := header.IPv6(newPacket.Bytes())
newIPHdr.Encode(&header.IPv6Fields{
PayloadLength: uint16(header.ICMPv6DstUnreachableMinimumSize + len(payload)),
TransportProtocol: header.ICMPv6ProtocolNumber,
SrcAddr: ipHdr.DestinationAddr(),
DstAddr: ipHdr.SourceAddr(),
})
icmpHdr := header.ICMPv6(newIPHdr.Payload())
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(code)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
Src: newIPHdr.SourceAddress(),
Dst: newIPHdr.DestinationAddress(),
PayloadCsum: checksum.Checksum(payload, 0),
PayloadLen: len(payload),
}))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
} else {
newPacket.Advance(-s.frontHeadroom)
}
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) WritePacket6(buffer *buf.Buffer, source netip.AddrPort, destination netip.AddrPort) error {
packet := buf.Get(header.IPv6MinimumSize + header.UDPMinimumSize + buffer.Len())
ipHdr := header.IPv6(packet)
ipHdr.Encode(&header.IPv6Fields{
PayloadLength: uint16(header.UDPMinimumSize + buffer.Len()),
TransportProtocol: header.UDPProtocolNumber,
SrcAddr: source.Addr(),
DstAddr: destination.Addr(),
})
udpHdr := header.UDP(ipHdr.Payload())
udpHdr.Encode(&header.UDPFields{
SrcPort: source.Port(),
DstPort: destination.Port(),
Length: uint16(header.UDPMinimumSize + buffer.Len()),
})
copy(udpHdr.Payload(), buffer.Bytes())
if !s.txChecksumOffload {
...
} else {
udpHdr.SetChecksum(0)
}
return common.Error(s.tun.Write(packet))
}*/
type systemUDPPacketWriter4 struct {
tun Tun
frontHeadroom int