diff --git a/common/utils/ordered_map.go b/common/utils/ordered_map.go new file mode 100644 index 00000000..9ffc70c2 --- /dev/null +++ b/common/utils/ordered_map.go @@ -0,0 +1,40 @@ +package utils + +// modify from https://github.com/go-yaml/yaml/issues/698#issuecomment-1482026841 + +import ( + "errors" + "gopkg.in/yaml.v3" +) + +type StringMapSlice[V any] []StringMapSliceItem[V] + +type StringMapSliceItem[V any] struct { + Key string + Value V +} + +func (s *StringMapSlice[V]) UnmarshalYAML(value *yaml.Node) error { + for i := 0; i < len(value.Content); i += 2 { + if i+1 >= len(value.Content) { + return errors.New("not a dict") + } + item := StringMapSliceItem[V]{} + item.Key = value.Content[i].Value + if err := value.Content[i+1].Decode(&item.Value); err != nil { + return err + } + + *s = append(*s, item) + } + + return nil +} + +func (s *StringMapSlice[V]) Add(key string, value V) { + *s = append(*s, StringMapSliceItem[V]{Key: key, Value: value}) +} + +func (i *StringMapSliceItem[V]) Extract() (key string, value V) { + return i.Key, i.Value +} diff --git a/config/config.go b/config/config.go index 76ff9d68..7ef395b3 100644 --- a/config/config.go +++ b/config/config.go @@ -114,7 +114,7 @@ type DNS struct { DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` FakeIPRange *fakeip.Pool Hosts *trie.DomainTrie[resolver.HostValue] - NameServerPolicy map[string][]dns.NameServer + NameServerPolicy utils.StringMapSlice[[]dns.NameServer] ProxyServerNameserver []dns.NameServer } @@ -193,21 +193,21 @@ type RawNTP struct { } type RawDNS struct { - Enable bool `yaml:"enable"` - PreferH3 bool `yaml:"prefer-h3"` - IPv6 bool `yaml:"ipv6"` - IPv6Timeout uint `yaml:"ipv6-timeout"` - UseHosts bool `yaml:"use-hosts"` - NameServer []string `yaml:"nameserver"` - Fallback []string `yaml:"fallback"` - FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` - Listen string `yaml:"listen"` - EnhancedMode C.DNSMode `yaml:"enhanced-mode"` - FakeIPRange string `yaml:"fake-ip-range"` - FakeIPFilter []string `yaml:"fake-ip-filter"` - DefaultNameserver []string `yaml:"default-nameserver"` - NameServerPolicy map[string]any `yaml:"nameserver-policy"` - ProxyServerNameserver []string `yaml:"proxy-server-nameserver"` + Enable bool `yaml:"enable"` + PreferH3 bool `yaml:"prefer-h3"` + IPv6 bool `yaml:"ipv6"` + IPv6Timeout uint `yaml:"ipv6-timeout"` + UseHosts bool `yaml:"use-hosts"` + NameServer []string `yaml:"nameserver"` + Fallback []string `yaml:"fallback"` + FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` + Listen string `yaml:"listen"` + EnhancedMode C.DNSMode `yaml:"enhanced-mode"` + FakeIPRange string `yaml:"fake-ip-range"` + FakeIPFilter []string `yaml:"fake-ip-filter"` + DefaultNameserver []string `yaml:"default-nameserver"` + NameServerPolicy utils.StringMapSlice[any] `yaml:"nameserver-policy"` + ProxyServerNameserver []string `yaml:"proxy-server-nameserver"` } type RawFallbackFilter struct { @@ -1085,12 +1085,13 @@ func parsePureDNSServer(server string) string { } } } -func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]providerTypes.RuleProvider, preferH3 bool) (map[string][]dns.NameServer, error) { - policy := map[string][]dns.NameServer{} - updatedPolicy := make(map[string]interface{}) +func parseNameServerPolicy(nsPolicy utils.StringMapSlice[any], ruleProviders map[string]providerTypes.RuleProvider, preferH3 bool) (utils.StringMapSlice[[]dns.NameServer], error) { + policy := utils.StringMapSlice[[]dns.NameServer]{} + updatedPolicy := utils.StringMapSlice[any]{} re := regexp.MustCompile(`[a-zA-Z0-9\-]+\.[a-zA-Z]{2,}(\.[a-zA-Z]{2,})?`) - for k, v := range nsPolicy { + for _, p := range nsPolicy { + k, v := p.Extract() if strings.Contains(k, ",") { if strings.Contains(k, "geosite:") { subkeys := strings.Split(k, ":") @@ -1098,7 +1099,7 @@ func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]pro subkeys = strings.Split(subkeys[0], ",") for _, subkey := range subkeys { newKey := "geosite:" + subkey - updatedPolicy[newKey] = v + updatedPolicy.Add(newKey, v) } } else if strings.Contains(k, "rule-set:") { subkeys := strings.Split(k, ":") @@ -1106,20 +1107,21 @@ func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]pro subkeys = strings.Split(subkeys[0], ",") for _, subkey := range subkeys { newKey := "rule-set:" + subkey - updatedPolicy[newKey] = v + updatedPolicy.Add(newKey, v) } } else if re.MatchString(k) { subkeys := strings.Split(k, ",") for _, subkey := range subkeys { - updatedPolicy[subkey] = v + updatedPolicy.Add(subkey, v) } } } else { - updatedPolicy[k] = v + updatedPolicy.Add(k, v) } } - for domain, server := range updatedPolicy { + for _, p := range updatedPolicy { + domain, server := p.Extract() servers, err := utils.ToStringSlice(server) if err != nil { return nil, err @@ -1144,7 +1146,7 @@ func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]pro } } } - policy[domain] = nameservers + policy.Add(domain, nameservers) } return policy, nil diff --git a/dns/policy.go b/dns/policy.go index a8b423e1..a58123e3 100644 --- a/dns/policy.go +++ b/dns/policy.go @@ -1,30 +1,50 @@ package dns -type Policy struct { - data []dnsClient +import ( + "github.com/metacubex/mihomo/component/trie" + C "github.com/metacubex/mihomo/constant" + "github.com/metacubex/mihomo/constant/provider" +) + +type dnsPolicy interface { + Match(domain string) []dnsClient } -func (p *Policy) GetData() []dnsClient { - return p.data +type domainTriePolicy struct { + *trie.DomainTrie[[]dnsClient] } -func (p *Policy) Compare(p2 *Policy) int { - if p2 == nil { - return 1 +func (p domainTriePolicy) Match(domain string) []dnsClient { + record := p.DomainTrie.Search(domain) + if record != nil { + return record.Data() } - l1 := len(p.data) - l2 := len(p2.data) - if l1 == l2 { - return 0 - } - if l1 > l2 { - return 1 - } - return -1 + return nil } -func NewPolicy(data []dnsClient) *Policy { - return &Policy{ - data: data, - } +type geositePolicy struct { + matcher fallbackDomainFilter + inverse bool + dnsClients []dnsClient +} + +func (p geositePolicy) Match(domain string) []dnsClient { + matched := p.matcher.Match(domain) + if matched != p.inverse { + return p.dnsClients + } + return nil +} + +type domainSetPolicy struct { + domainSetProvider provider.RuleProvider + dnsClients []dnsClient +} + +func (p domainSetPolicy) Match(domain string) []dnsClient { + metadata := &C.Metadata{Host: domain} + if ok := p.domainSetProvider.Match(metadata); ok { + return p.dnsClients + } + return nil } diff --git a/dns/resolver.go b/dns/resolver.go index 610a06f0..7e45252e 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -8,6 +8,7 @@ import ( "time" "github.com/metacubex/mihomo/common/cache" + "github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/component/fakeip" "github.com/metacubex/mihomo/component/geodata/router" "github.com/metacubex/mihomo/component/resolver" @@ -31,17 +32,6 @@ type result struct { Error error } -type geositePolicyRecord struct { - matcher fallbackDomainFilter - policy *Policy - inversedMatching bool -} - -type domainSetPolicyRecord struct { - domainSetProvider provider.RuleProvider - policy *Policy -} - type Resolver struct { ipv6 bool ipv6Timeout time.Duration @@ -52,9 +42,7 @@ type Resolver struct { fallbackIPFilters []fallbackIPFilter group singleflight.Group lruCache *cache.LruCache[string, *D.Msg] - policy *trie.DomainTrie[*Policy] - domainSetPolicy []domainSetPolicyRecord - geositePolicy []geositePolicyRecord + policy []dnsPolicy proxyServer []dnsClient } @@ -258,22 +246,9 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { return nil } - record := r.policy.Search(domain) - if record != nil { - p := record.Data() - return p.GetData() - } - - for _, geositeRecord := range r.geositePolicy { - matched := geositeRecord.matcher.Match(domain) - if matched != geositeRecord.inversedMatching { - return geositeRecord.policy.GetData() - } - } - metadata := &C.Metadata{Host: domain} - for _, domainSetRecord := range r.domainSetPolicy { - if ok := domainSetRecord.domainSetProvider.Match(metadata); ok { - return domainSetRecord.policy.GetData() + for _, policy := range r.policy { + if dnsClients := policy.Match(domain); len(dnsClients) > 0 { + return dnsClients } } return nil @@ -404,18 +379,17 @@ type FallbackFilter struct { } type Config struct { - Main, Fallback []NameServer - Default []NameServer - ProxyServer []NameServer - IPv6 bool - IPv6Timeout uint - EnhancedMode C.DNSMode - FallbackFilter FallbackFilter - Pool *fakeip.Pool - Hosts *trie.DomainTrie[resolver.HostValue] - Policy map[string][]NameServer - DomainSetPolicy map[provider.RuleProvider][]NameServer - GeositePolicy map[router.DomainMatcher][]NameServer + Main, Fallback []NameServer + Default []NameServer + ProxyServer []NameServer + IPv6 bool + IPv6Timeout uint + EnhancedMode C.DNSMode + FallbackFilter FallbackFilter + Pool *fakeip.Pool + Hosts *trie.DomainTrie[resolver.HostValue] + Policy utils.StringMapSlice[[]NameServer] + RuleProviders map[string]provider.RuleProvider } func NewResolver(config Config) *Resolver { @@ -442,38 +416,59 @@ func NewResolver(config Config) *Resolver { } if len(config.Policy) != 0 { - r.policy = trie.New[*Policy]() - for domain, nameserver := range config.Policy { - if strings.HasPrefix(strings.ToLower(domain), "geosite:") { - groupname := domain[8:] - inverse := false - if strings.HasPrefix(groupname, "!") { - inverse = true - groupname = groupname[1:] - } - log.Debugln("adding geosite policy: %s inversed %t", groupname, inverse) - matcher, err := NewGeoSite(groupname) - if err != nil { - continue - } - r.geositePolicy = append(r.geositePolicy, geositePolicyRecord{ - matcher: matcher, - policy: NewPolicy(transform(nameserver, defaultResolver)), - inversedMatching: inverse, - }) - } else { - _ = r.policy.Insert(domain, NewPolicy(transform(nameserver, defaultResolver))) + r.policy = make([]dnsPolicy, 0) + + var triePolicy *trie.DomainTrie[[]dnsClient] + insertTriePolicy := func() { + if triePolicy != nil { + triePolicy.Optimize() + r.policy = append(r.policy, domainTriePolicy{triePolicy}) + triePolicy = nil } } - r.policy.Optimize() - } - if len(config.DomainSetPolicy) > 0 { - for p, n := range config.DomainSetPolicy { - r.domainSetPolicy = append(r.domainSetPolicy, domainSetPolicyRecord{ - domainSetProvider: p, - policy: NewPolicy(transform(n, defaultResolver)), - }) + for _, p := range config.Policy { + domain, nameserver := p.Extract() + domain = strings.ToLower(domain) + + if temp := strings.Split(domain, ":"); len(temp) == 2 { + prefix := temp[0] + key := temp[1] + switch strings.ToLower(prefix) { + case "rule-set": + if p, ok := config.RuleProviders[key]; ok { + insertTriePolicy() + r.policy = append(r.policy, domainSetPolicy{ + domainSetProvider: p, + dnsClients: transform(nameserver, defaultResolver), + }) + continue + } + case "geosite": + inverse := false + if strings.HasPrefix(key, "!") { + inverse = true + key = key[1:] + } + log.Debugln("adding geosite policy: %s inversed %t", key, inverse) + matcher, err := NewGeoSite(key) + if err != nil { + continue + } + insertTriePolicy() + r.policy = append(r.policy, geositePolicy{ + matcher: matcher, + inverse: inverse, + dnsClients: transform(nameserver, defaultResolver), + }) + continue + } + } + if triePolicy == nil { + triePolicy = trie.New[[]dnsClient]() + } + _ = triePolicy.Insert(domain, transform(nameserver, defaultResolver)) } + insertTriePolicy() } fallbackIPFilters := []fallbackIPFilter{} @@ -508,7 +503,6 @@ func NewProxyServerHostResolver(old *Resolver) *Resolver { main: old.proxyServer, lruCache: old.lruCache, hosts: old.hosts, - policy: trie.New[*Policy](), ipv6Timeout: old.ipv6Timeout, } return r diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 6ea02989..f1c108cf 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -7,7 +7,6 @@ import ( "os" "runtime" "strconv" - "strings" "sync" "time" @@ -208,25 +207,6 @@ func updateDNS(c *config.DNS, ruleProvider map[string]provider.RuleProvider, gen dns.ReCreateServer("", nil, nil) return } - policy := make(map[string][]dns.NameServer) - domainSetPolicies := make(map[provider.RuleProvider][]dns.NameServer) - for key, nameservers := range c.NameServerPolicy { - temp := strings.Split(key, ":") - if len(temp) == 2 { - prefix := temp[0] - key := temp[1] - switch strings.ToLower(prefix) { - case "rule-set": - if p, ok := ruleProvider[key]; ok { - domainSetPolicies[p] = nameservers - } - case "geosite": - // TODO: - } - } else { - policy[key] = nameservers - } - } cfg := dns.Config{ Main: c.NameServer, Fallback: c.Fallback, @@ -242,10 +222,10 @@ func updateDNS(c *config.DNS, ruleProvider map[string]provider.RuleProvider, gen Domain: c.FallbackFilter.Domain, GeoSite: c.FallbackFilter.GeoSite, }, - Default: c.DefaultNameserver, - Policy: c.NameServerPolicy, - ProxyServer: c.ProxyServerNameserver, - DomainSetPolicy: domainSetPolicies, + Default: c.DefaultNameserver, + Policy: c.NameServerPolicy, + ProxyServer: c.ProxyServerNameserver, + RuleProviders: ruleProvider, } r := dns.NewResolver(cfg)