From 7e10d78d53e9a23f0bec4cd24afa84af767a5fab Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Thu, 23 Mar 2023 18:35:37 +0800 Subject: [PATCH] chore: share the same geodata in different rule --- component/geodata/utils.go | 42 +++++++++++++++++++++++++++----------- config/updateGeo.go | 2 ++ dns/filters.go | 25 +++-------------------- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/component/geodata/utils.go b/component/geodata/utils.go index 9e7e50b1..f1ea7151 100644 --- a/component/geodata/utils.go +++ b/component/geodata/utils.go @@ -2,6 +2,9 @@ package geodata import ( "fmt" + "golang.org/x/sync/singleflight" + "strings" + "github.com/Dreamacro/clash/component/geodata/router" C "github.com/Dreamacro/clash/constant" ) @@ -34,6 +37,8 @@ func Verify(name string) error { } } +var loadGeoSiteMatcherSF = singleflight.Group{} + func LoadGeoSiteMatcher(countryCode string) (*router.DomainMatcher, int, error) { if len(countryCode) == 0 { return nil, 0, fmt.Errorf("country code could not be empty") @@ -44,16 +49,19 @@ func LoadGeoSiteMatcher(countryCode string) (*router.DomainMatcher, int, error) not = true countryCode = countryCode[1:] } + countryCode = strings.ToLower(countryCode) - geoLoader, err := GetGeoDataLoader(geoLoaderName) - if err != nil { - return nil, 0, err - } - - domains, err := geoLoader.LoadGeoSite(countryCode) + v, err, _ := loadGeoSiteMatcherSF.Do(countryCode, func() (interface{}, error) { + geoLoader, err := GetGeoDataLoader(geoLoaderName) + if err != nil { + return nil, err + } + return geoLoader.LoadGeoSite(countryCode) + }) if err != nil { return nil, 0, err } + domains := v.([]*router.Domain) /** linear: linear algorithm @@ -68,25 +76,31 @@ func LoadGeoSiteMatcher(countryCode string) (*router.DomainMatcher, int, error) return matcher, len(domains), nil } +var loadGeoIPMatcherSF = singleflight.Group{} + func LoadGeoIPMatcher(country string) (*router.GeoIPMatcher, int, error) { if len(country) == 0 { return nil, 0, fmt.Errorf("country code could not be empty") } - geoLoader, err := GetGeoDataLoader(geoLoaderName) - if err != nil { - return nil, 0, err - } not := false if country[0] == '!' { not = true country = country[1:] } + country = strings.ToLower(country) - records, err := geoLoader.LoadGeoIP(country) + v, err, _ := loadGeoIPMatcherSF.Do(country, func() (interface{}, error) { + geoLoader, err := GetGeoDataLoader(geoLoaderName) + if err != nil { + return nil, err + } + return geoLoader.LoadGeoIP(country) + }) if err != nil { return nil, 0, err } + records := v.([]*router.CIDR) geoIP := &router.GeoIP{ CountryCode: country, @@ -98,6 +112,10 @@ func LoadGeoIPMatcher(country string) (*router.GeoIPMatcher, int, error) { if err != nil { return nil, 0, err } - return matcher, len(records), nil } + +func ClearCache() { + loadGeoSiteMatcherSF = singleflight.Group{} + loadGeoIPMatcherSF = singleflight.Group{} +} diff --git a/config/updateGeo.go b/config/updateGeo.go index a5f7b17b..698bd52d 100644 --- a/config/updateGeo.go +++ b/config/updateGeo.go @@ -63,6 +63,8 @@ func UpdateGeoDatabases() error { return fmt.Errorf("can't save GeoSite database file: %w", err) } + geodata.ClearCache() + return nil } diff --git a/dns/filters.go b/dns/filters.go index 11c85c2c..58b261ac 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -29,29 +29,10 @@ func (gf *geoipFilter) Match(ip netip.Addr) bool { } if geoIPMatcher == nil { - countryCode := "cn" - geoLoader, err := geodata.GetGeoDataLoader(geodata.LoaderName()) + var err error + geoIPMatcher, _, err = geodata.LoadGeoIPMatcher("CN") if err != nil { - log.Errorln("[GeoIPFilter] GetGeoDataLoader error: %s", err.Error()) - return false - } - - records, err := geoLoader.LoadGeoIP(countryCode) - if err != nil { - log.Errorln("[GeoIPFilter] LoadGeoIP error: %s", err.Error()) - return false - } - - geoIP := &router.GeoIP{ - CountryCode: countryCode, - Cidr: records, - ReverseMatch: false, - } - - geoIPMatcher, err = router.NewGeoIPMatcher(geoIP) - - if err != nil { - log.Errorln("[GeoIPFilter] NewGeoIPMatcher error: %s", err.Error()) + log.Errorln("[GeoIPFilter] LoadGeoIPMatcher error: %s", err.Error()) return false } }