diff --git a/component/cidr/ipcidr_set.go b/component/cidr/ipcidr_set.go index dac2d38a..0cb55e36 100644 --- a/component/cidr/ipcidr_set.go +++ b/component/cidr/ipcidr_set.go @@ -1,12 +1,16 @@ package cidr import ( - "go4.org/netipx" + "fmt" "net/netip" + "unsafe" + + "go4.org/netipx" ) type IpCidrSet struct { - Ranges *netipx.IPSet + // must same with netipx.IPSet + rr []netipx.IPRange } func NewIpCidrSet() *IpCidrSet { @@ -18,15 +22,15 @@ func (set *IpCidrSet) AddIpCidrForString(ipCidr string) error { if err != nil { return err } - err = set.AddIpCidr(prefix) - return nil + return set.AddIpCidr(prefix) } func (set *IpCidrSet) AddIpCidr(ipCidr netip.Prefix) (err error) { - var b netipx.IPSetBuilder - b.AddSet(set.Ranges) - b.AddPrefix(ipCidr) - set.Ranges, err = b.IPSet() + if r := netipx.RangeOfPrefix(ipCidr); r.IsValid() { + set.rr = append(set.rr, r) + } else { + err = fmt.Errorf("not valid ipcidr range: %s", ipCidr) + } return } @@ -39,10 +43,24 @@ func (set *IpCidrSet) IsContainForString(ipString string) bool { } func (set *IpCidrSet) IsContain(ip netip.Addr) bool { - if set.Ranges == nil { - return false - } - return set.Ranges.Contains(ip.WithZone("")) + return set.toIPSet().Contains(ip.WithZone("")) } -func (set *IpCidrSet) Merge() {} +func (set *IpCidrSet) Merge() error { + var b netipx.IPSetBuilder + b.AddSet(set.toIPSet()) + i, err := b.IPSet() + if err != nil { + return err + } + set.fromIPSet(i) + return nil +} + +func (set *IpCidrSet) toIPSet() *netipx.IPSet { + return (*netipx.IPSet)(unsafe.Pointer(set)) +} + +func (set *IpCidrSet) fromIPSet(i *netipx.IPSet) { + *set = *(*IpCidrSet)(unsafe.Pointer(i)) +} diff --git a/component/cidr/ipcidr_set_test.go b/component/cidr/ipcidr_set_test.go index a6eaec84..b229aa2b 100644 --- a/component/cidr/ipcidr_set_test.go +++ b/component/cidr/ipcidr_set_test.go @@ -1,9 +1,7 @@ package cidr import ( - "go4.org/netipx" "testing" - "unsafe" ) func TestIpv4(t *testing.T) { @@ -99,7 +97,7 @@ func TestMerge(t *testing.T) { set.AddIpCidrForString(test.ipCidr2) set.Merge() - rangesLen := len(*(*[]netipx.IPRange)(unsafe.Pointer(set.Ranges))) + rangesLen := len(set.rr) if rangesLen != test.expectedLen { t.Errorf("Expected len: %v, got: %v", test.expectedLen, rangesLen) diff --git a/component/geodata/router/condition.go b/component/geodata/router/condition.go index c2ac8071..5261d2fe 100644 --- a/component/geodata/router/condition.go +++ b/component/geodata/router/condition.go @@ -137,9 +137,7 @@ func (m *GeoIPMatcher) Init(cidrs []*CIDR) error { return fmt.Errorf("error when loading GeoIP: %w", err) } } - m.cidrSet.Merge() - - return nil + return m.cidrSet.Merge() } func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) {