diff --git a/option/dns.go b/option/dns.go index acca67b0..a135cf5c 100644 --- a/option/dns.go +++ b/option/dns.go @@ -90,27 +90,29 @@ func (r *DNSRule) UnmarshalJSON(bytes []byte) error { } type DefaultDNSRule struct { - Inbound Listable[string] `json:"inbound,omitempty"` - Network string `json:"network,omitempty"` - AuthUser Listable[string] `json:"auth_user,omitempty"` - Protocol Listable[string] `json:"protocol,omitempty"` - Domain Listable[string] `json:"domain,omitempty"` - DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` - DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` - DomainRegex Listable[string] `json:"domain_regex,omitempty"` - Geosite Listable[string] `json:"geosite,omitempty"` - SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` - SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` - SourcePort Listable[uint16] `json:"source_port,omitempty"` - Port Listable[uint16] `json:"port,omitempty"` - ProcessName Listable[string] `json:"process_name,omitempty"` - PackageName Listable[string] `json:"package_name,omitempty"` - User Listable[string] `json:"user,omitempty"` - UserID Listable[int32] `json:"user_id,omitempty"` - Outbound Listable[string] `json:"outbound,omitempty"` - Invert bool `json:"invert,omitempty"` - Server string `json:"server,omitempty"` - DisableCache bool `json:"disable_cache,omitempty"` + Inbound Listable[string] `json:"inbound,omitempty"` + Network string `json:"network,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` + Protocol Listable[string] `json:"protocol,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + User Listable[string] `json:"user,omitempty"` + UserID Listable[int32] `json:"user_id,omitempty"` + Outbound Listable[string] `json:"outbound,omitempty"` + Invert bool `json:"invert,omitempty"` + Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` } func (r DefaultDNSRule) IsValid() bool { @@ -132,7 +134,9 @@ func (r DefaultDNSRule) Equals(other DefaultDNSRule) bool { common.ComparableSliceEquals(r.SourceGeoIP, other.SourceGeoIP) && common.ComparableSliceEquals(r.SourceIPCIDR, other.SourceIPCIDR) && common.ComparableSliceEquals(r.SourcePort, other.SourcePort) && + common.ComparableSliceEquals(r.SourcePortRange, other.SourcePortRange) && common.ComparableSliceEquals(r.Port, other.Port) && + common.ComparableSliceEquals(r.PortRange, other.PortRange) && common.ComparableSliceEquals(r.ProcessName, other.ProcessName) && common.ComparableSliceEquals(r.UserID, other.UserID) && common.ComparableSliceEquals(r.PackageName, other.PackageName) && diff --git a/option/route.go b/option/route.go index 4083fb1f..1b9cea34 100644 --- a/option/route.go +++ b/option/route.go @@ -87,28 +87,30 @@ func (r *Rule) UnmarshalJSON(bytes []byte) error { } type DefaultRule struct { - Inbound Listable[string] `json:"inbound,omitempty"` - IPVersion int `json:"ip_version,omitempty"` - Network string `json:"network,omitempty"` - AuthUser Listable[string] `json:"auth_user,omitempty"` - Protocol Listable[string] `json:"protocol,omitempty"` - Domain Listable[string] `json:"domain,omitempty"` - DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` - DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` - DomainRegex Listable[string] `json:"domain_regex,omitempty"` - Geosite Listable[string] `json:"geosite,omitempty"` - SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` - GeoIP Listable[string] `json:"geoip,omitempty"` - SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` - IPCIDR Listable[string] `json:"ip_cidr,omitempty"` - SourcePort Listable[uint16] `json:"source_port,omitempty"` - Port Listable[uint16] `json:"port,omitempty"` - ProcessName Listable[string] `json:"process_name,omitempty"` - PackageName Listable[string] `json:"package_name,omitempty"` - User Listable[string] `json:"user,omitempty"` - UserID Listable[int32] `json:"user_id,omitempty"` - Invert bool `json:"invert,omitempty"` - Outbound string `json:"outbound,omitempty"` + Inbound Listable[string] `json:"inbound,omitempty"` + IPVersion int `json:"ip_version,omitempty"` + Network string `json:"network,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` + Protocol Listable[string] `json:"protocol,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + GeoIP Listable[string] `json:"geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + IPCIDR Listable[string] `json:"ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + User Listable[string] `json:"user,omitempty"` + UserID Listable[int32] `json:"user_id,omitempty"` + Invert bool `json:"invert,omitempty"` + Outbound string `json:"outbound,omitempty"` } func (r DefaultRule) IsValid() bool { @@ -133,7 +135,9 @@ func (r DefaultRule) Equals(other DefaultRule) bool { common.ComparableSliceEquals(r.SourceIPCIDR, other.SourceIPCIDR) && common.ComparableSliceEquals(r.IPCIDR, other.IPCIDR) && common.ComparableSliceEquals(r.SourcePort, other.SourcePort) && + common.ComparableSliceEquals(r.SourcePortRange, other.SourcePortRange) && common.ComparableSliceEquals(r.Port, other.Port) && + common.ComparableSliceEquals(r.PortRange, other.PortRange) && common.ComparableSliceEquals(r.ProcessName, other.ProcessName) && common.ComparableSliceEquals(r.PackageName, other.PackageName) && common.ComparableSliceEquals(r.User, other.User) && diff --git a/route/rule.go b/route/rule.go index c4148478..69dfd8d4 100644 --- a/route/rule.go +++ b/route/rule.go @@ -148,11 +148,27 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.SourcePortRange) > 0 { + item, err := NewPortRangeItem(true, options.SourcePortRange) + if err != nil { + return nil, E.Cause(err, "source_port_range") + } + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.Port) > 0 { item := NewPortItem(false, options.Port) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.PortRange) > 0 { + item, err := NewPortRangeItem(false, options.PortRange) + if err != nil { + return nil, E.Cause(err, "port_range") + } + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.ProcessName) > 0 { item := NewProcessItem(options.ProcessName) rule.items = append(rule.items, item) diff --git a/route/rule_dns.go b/route/rule_dns.go index 8fd0e2ac..bda8e446 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -121,11 +121,27 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.SourcePortRange) > 0 { + item, err := NewPortRangeItem(true, options.SourcePortRange) + if err != nil { + return nil, E.Cause(err, "source_port_range") + } + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.Port) > 0 { item := NewPortItem(false, options.Port) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.PortRange) > 0 { + item, err := NewPortRangeItem(false, options.PortRange) + if err != nil { + return nil, E.Cause(err, "port_range") + } + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } if len(options.ProcessName) > 0 { item := NewProcessItem(options.ProcessName) rule.items = append(rule.items, item) diff --git a/route/rule_port_range.go b/route/rule_port_range.go new file mode 100644 index 00000000..6e8b2724 --- /dev/null +++ b/route/rule_port_range.go @@ -0,0 +1,87 @@ +package route + +import ( + "strconv" + "strings" + + "github.com/sagernet/sing-box/adapter" + E "github.com/sagernet/sing/common/exceptions" +) + +var ErrBadPortRange = E.New("bad port range") + +var _ RuleItem = (*PortRangeItem)(nil) + +type PortRangeItem struct { + isSource bool + portRanges []string + portRangeList []rangeItem +} + +type rangeItem struct { + start uint16 + end uint16 +} + +func NewPortRangeItem(isSource bool, rangeList []string) (*PortRangeItem, error) { + portRangeList := make([]rangeItem, 0, len(rangeList)) + for _, portRange := range rangeList { + if !strings.Contains(portRange, ":") { + return nil, E.Extend(ErrBadPortRange, portRange) + } + subIndex := strings.Index(portRange, ":") + var start, end uint64 + var err error + if subIndex > 0 { + start, err = strconv.ParseUint(portRange[:subIndex], 10, 16) + if err != nil { + return nil, E.Cause(err, E.Extend(ErrBadPortRange, portRange)) + } + } + if subIndex == len(portRange)-1 { + end = 0xFF + } else { + end, err = strconv.ParseUint(portRange[subIndex+1:], 10, 16) + if err != nil { + return nil, E.Cause(err, E.Extend(ErrBadPortRange, portRange)) + } + } + portRangeList = append(portRangeList, rangeItem{uint16(start), uint16(end)}) + } + return &PortRangeItem{ + isSource: isSource, + portRanges: rangeList, + portRangeList: portRangeList, + }, nil +} + +func (r *PortRangeItem) Match(metadata *adapter.InboundContext) bool { + var port uint16 + if r.isSource { + port = metadata.Source.Port + } else { + port = metadata.Destination.Port + } + for _, portRange := range r.portRangeList { + if port >= portRange.start && port <= portRange.end { + return true + } + } + return false +} + +func (r *PortRangeItem) String() string { + var description string + if r.isSource { + description = "source_port_range=" + } else { + description = "port_range=" + } + pLen := len(r.portRanges) + if pLen == 1 { + description += r.portRanges[0] + } else { + description += "[" + strings.Join(r.portRanges, " ") + "]" + } + return description +}