chore: NameServerPolicy will match inorder

This commit is contained in:
wwqgtxx 2023-11-08 19:29:26 +08:00
parent 17c9d507be
commit 575c1d4129
5 changed files with 178 additions and 142 deletions

View file

@ -0,0 +1,40 @@
package utils
// modify from https://github.com/go-yaml/yaml/issues/698#issuecomment-1482026841
import (
"errors"
"gopkg.in/yaml.v3"
)
type StringMapSlice[V any] []StringMapSliceItem[V]
type StringMapSliceItem[V any] struct {
Key string
Value V
}
func (s *StringMapSlice[V]) UnmarshalYAML(value *yaml.Node) error {
for i := 0; i < len(value.Content); i += 2 {
if i+1 >= len(value.Content) {
return errors.New("not a dict")
}
item := StringMapSliceItem[V]{}
item.Key = value.Content[i].Value
if err := value.Content[i+1].Decode(&item.Value); err != nil {
return err
}
*s = append(*s, item)
}
return nil
}
func (s *StringMapSlice[V]) Add(key string, value V) {
*s = append(*s, StringMapSliceItem[V]{Key: key, Value: value})
}
func (i *StringMapSliceItem[V]) Extract() (key string, value V) {
return i.Key, i.Value
}

View file

@ -114,7 +114,7 @@ type DNS struct {
DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` DefaultNameserver []dns.NameServer `yaml:"default-nameserver"`
FakeIPRange *fakeip.Pool FakeIPRange *fakeip.Pool
Hosts *trie.DomainTrie[resolver.HostValue] Hosts *trie.DomainTrie[resolver.HostValue]
NameServerPolicy map[string][]dns.NameServer NameServerPolicy utils.StringMapSlice[[]dns.NameServer]
ProxyServerNameserver []dns.NameServer ProxyServerNameserver []dns.NameServer
} }
@ -193,21 +193,21 @@ type RawNTP struct {
} }
type RawDNS struct { type RawDNS struct {
Enable bool `yaml:"enable"` Enable bool `yaml:"enable"`
PreferH3 bool `yaml:"prefer-h3"` PreferH3 bool `yaml:"prefer-h3"`
IPv6 bool `yaml:"ipv6"` IPv6 bool `yaml:"ipv6"`
IPv6Timeout uint `yaml:"ipv6-timeout"` IPv6Timeout uint `yaml:"ipv6-timeout"`
UseHosts bool `yaml:"use-hosts"` UseHosts bool `yaml:"use-hosts"`
NameServer []string `yaml:"nameserver"` NameServer []string `yaml:"nameserver"`
Fallback []string `yaml:"fallback"` Fallback []string `yaml:"fallback"`
FallbackFilter RawFallbackFilter `yaml:"fallback-filter"` FallbackFilter RawFallbackFilter `yaml:"fallback-filter"`
Listen string `yaml:"listen"` Listen string `yaml:"listen"`
EnhancedMode C.DNSMode `yaml:"enhanced-mode"` EnhancedMode C.DNSMode `yaml:"enhanced-mode"`
FakeIPRange string `yaml:"fake-ip-range"` FakeIPRange string `yaml:"fake-ip-range"`
FakeIPFilter []string `yaml:"fake-ip-filter"` FakeIPFilter []string `yaml:"fake-ip-filter"`
DefaultNameserver []string `yaml:"default-nameserver"` DefaultNameserver []string `yaml:"default-nameserver"`
NameServerPolicy map[string]any `yaml:"nameserver-policy"` NameServerPolicy utils.StringMapSlice[any] `yaml:"nameserver-policy"`
ProxyServerNameserver []string `yaml:"proxy-server-nameserver"` ProxyServerNameserver []string `yaml:"proxy-server-nameserver"`
} }
type RawFallbackFilter struct { type RawFallbackFilter struct {
@ -1085,12 +1085,13 @@ func parsePureDNSServer(server string) string {
} }
} }
} }
func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]providerTypes.RuleProvider, preferH3 bool) (map[string][]dns.NameServer, error) { func parseNameServerPolicy(nsPolicy utils.StringMapSlice[any], ruleProviders map[string]providerTypes.RuleProvider, preferH3 bool) (utils.StringMapSlice[[]dns.NameServer], error) {
policy := map[string][]dns.NameServer{} policy := utils.StringMapSlice[[]dns.NameServer]{}
updatedPolicy := make(map[string]interface{}) updatedPolicy := utils.StringMapSlice[any]{}
re := regexp.MustCompile(`[a-zA-Z0-9\-]+\.[a-zA-Z]{2,}(\.[a-zA-Z]{2,})?`) re := regexp.MustCompile(`[a-zA-Z0-9\-]+\.[a-zA-Z]{2,}(\.[a-zA-Z]{2,})?`)
for k, v := range nsPolicy { for _, p := range nsPolicy {
k, v := p.Extract()
if strings.Contains(k, ",") { if strings.Contains(k, ",") {
if strings.Contains(k, "geosite:") { if strings.Contains(k, "geosite:") {
subkeys := strings.Split(k, ":") subkeys := strings.Split(k, ":")
@ -1098,7 +1099,7 @@ func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]pro
subkeys = strings.Split(subkeys[0], ",") subkeys = strings.Split(subkeys[0], ",")
for _, subkey := range subkeys { for _, subkey := range subkeys {
newKey := "geosite:" + subkey newKey := "geosite:" + subkey
updatedPolicy[newKey] = v updatedPolicy.Add(newKey, v)
} }
} else if strings.Contains(k, "rule-set:") { } else if strings.Contains(k, "rule-set:") {
subkeys := strings.Split(k, ":") subkeys := strings.Split(k, ":")
@ -1106,20 +1107,21 @@ func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]pro
subkeys = strings.Split(subkeys[0], ",") subkeys = strings.Split(subkeys[0], ",")
for _, subkey := range subkeys { for _, subkey := range subkeys {
newKey := "rule-set:" + subkey newKey := "rule-set:" + subkey
updatedPolicy[newKey] = v updatedPolicy.Add(newKey, v)
} }
} else if re.MatchString(k) { } else if re.MatchString(k) {
subkeys := strings.Split(k, ",") subkeys := strings.Split(k, ",")
for _, subkey := range subkeys { for _, subkey := range subkeys {
updatedPolicy[subkey] = v updatedPolicy.Add(subkey, v)
} }
} }
} else { } else {
updatedPolicy[k] = v updatedPolicy.Add(k, v)
} }
} }
for domain, server := range updatedPolicy { for _, p := range updatedPolicy {
domain, server := p.Extract()
servers, err := utils.ToStringSlice(server) servers, err := utils.ToStringSlice(server)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1144,7 +1146,7 @@ func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]pro
} }
} }
} }
policy[domain] = nameservers policy.Add(domain, nameservers)
} }
return policy, nil return policy, nil

View file

@ -1,30 +1,50 @@
package dns package dns
type Policy struct { import (
data []dnsClient "github.com/metacubex/mihomo/component/trie"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/provider"
)
type dnsPolicy interface {
Match(domain string) []dnsClient
} }
func (p *Policy) GetData() []dnsClient { type domainTriePolicy struct {
return p.data *trie.DomainTrie[[]dnsClient]
} }
func (p *Policy) Compare(p2 *Policy) int { func (p domainTriePolicy) Match(domain string) []dnsClient {
if p2 == nil { record := p.DomainTrie.Search(domain)
return 1 if record != nil {
return record.Data()
} }
l1 := len(p.data) return nil
l2 := len(p2.data)
if l1 == l2 {
return 0
}
if l1 > l2 {
return 1
}
return -1
} }
func NewPolicy(data []dnsClient) *Policy { type geositePolicy struct {
return &Policy{ matcher fallbackDomainFilter
data: data, inverse bool
} dnsClients []dnsClient
}
func (p geositePolicy) Match(domain string) []dnsClient {
matched := p.matcher.Match(domain)
if matched != p.inverse {
return p.dnsClients
}
return nil
}
type domainSetPolicy struct {
domainSetProvider provider.RuleProvider
dnsClients []dnsClient
}
func (p domainSetPolicy) Match(domain string) []dnsClient {
metadata := &C.Metadata{Host: domain}
if ok := p.domainSetProvider.Match(metadata); ok {
return p.dnsClients
}
return nil
} }

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/metacubex/mihomo/common/cache" "github.com/metacubex/mihomo/common/cache"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/fakeip" "github.com/metacubex/mihomo/component/fakeip"
"github.com/metacubex/mihomo/component/geodata/router" "github.com/metacubex/mihomo/component/geodata/router"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
@ -31,17 +32,6 @@ type result struct {
Error error Error error
} }
type geositePolicyRecord struct {
matcher fallbackDomainFilter
policy *Policy
inversedMatching bool
}
type domainSetPolicyRecord struct {
domainSetProvider provider.RuleProvider
policy *Policy
}
type Resolver struct { type Resolver struct {
ipv6 bool ipv6 bool
ipv6Timeout time.Duration ipv6Timeout time.Duration
@ -52,9 +42,7 @@ type Resolver struct {
fallbackIPFilters []fallbackIPFilter fallbackIPFilters []fallbackIPFilter
group singleflight.Group group singleflight.Group
lruCache *cache.LruCache[string, *D.Msg] lruCache *cache.LruCache[string, *D.Msg]
policy *trie.DomainTrie[*Policy] policy []dnsPolicy
domainSetPolicy []domainSetPolicyRecord
geositePolicy []geositePolicyRecord
proxyServer []dnsClient proxyServer []dnsClient
} }
@ -258,22 +246,9 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
return nil return nil
} }
record := r.policy.Search(domain) for _, policy := range r.policy {
if record != nil { if dnsClients := policy.Match(domain); len(dnsClients) > 0 {
p := record.Data() return dnsClients
return p.GetData()
}
for _, geositeRecord := range r.geositePolicy {
matched := geositeRecord.matcher.Match(domain)
if matched != geositeRecord.inversedMatching {
return geositeRecord.policy.GetData()
}
}
metadata := &C.Metadata{Host: domain}
for _, domainSetRecord := range r.domainSetPolicy {
if ok := domainSetRecord.domainSetProvider.Match(metadata); ok {
return domainSetRecord.policy.GetData()
} }
} }
return nil return nil
@ -404,18 +379,17 @@ type FallbackFilter struct {
} }
type Config struct { type Config struct {
Main, Fallback []NameServer Main, Fallback []NameServer
Default []NameServer Default []NameServer
ProxyServer []NameServer ProxyServer []NameServer
IPv6 bool IPv6 bool
IPv6Timeout uint IPv6Timeout uint
EnhancedMode C.DNSMode EnhancedMode C.DNSMode
FallbackFilter FallbackFilter FallbackFilter FallbackFilter
Pool *fakeip.Pool Pool *fakeip.Pool
Hosts *trie.DomainTrie[resolver.HostValue] Hosts *trie.DomainTrie[resolver.HostValue]
Policy map[string][]NameServer Policy utils.StringMapSlice[[]NameServer]
DomainSetPolicy map[provider.RuleProvider][]NameServer RuleProviders map[string]provider.RuleProvider
GeositePolicy map[router.DomainMatcher][]NameServer
} }
func NewResolver(config Config) *Resolver { func NewResolver(config Config) *Resolver {
@ -442,38 +416,59 @@ func NewResolver(config Config) *Resolver {
} }
if len(config.Policy) != 0 { if len(config.Policy) != 0 {
r.policy = trie.New[*Policy]() r.policy = make([]dnsPolicy, 0)
for domain, nameserver := range config.Policy {
if strings.HasPrefix(strings.ToLower(domain), "geosite:") { var triePolicy *trie.DomainTrie[[]dnsClient]
groupname := domain[8:] insertTriePolicy := func() {
inverse := false if triePolicy != nil {
if strings.HasPrefix(groupname, "!") { triePolicy.Optimize()
inverse = true r.policy = append(r.policy, domainTriePolicy{triePolicy})
groupname = groupname[1:] triePolicy = nil
}
log.Debugln("adding geosite policy: %s inversed %t", groupname, inverse)
matcher, err := NewGeoSite(groupname)
if err != nil {
continue
}
r.geositePolicy = append(r.geositePolicy, geositePolicyRecord{
matcher: matcher,
policy: NewPolicy(transform(nameserver, defaultResolver)),
inversedMatching: inverse,
})
} else {
_ = r.policy.Insert(domain, NewPolicy(transform(nameserver, defaultResolver)))
} }
} }
r.policy.Optimize() for _, p := range config.Policy {
} domain, nameserver := p.Extract()
if len(config.DomainSetPolicy) > 0 { domain = strings.ToLower(domain)
for p, n := range config.DomainSetPolicy {
r.domainSetPolicy = append(r.domainSetPolicy, domainSetPolicyRecord{ if temp := strings.Split(domain, ":"); len(temp) == 2 {
domainSetProvider: p, prefix := temp[0]
policy: NewPolicy(transform(n, defaultResolver)), key := temp[1]
}) switch strings.ToLower(prefix) {
case "rule-set":
if p, ok := config.RuleProviders[key]; ok {
insertTriePolicy()
r.policy = append(r.policy, domainSetPolicy{
domainSetProvider: p,
dnsClients: transform(nameserver, defaultResolver),
})
continue
}
case "geosite":
inverse := false
if strings.HasPrefix(key, "!") {
inverse = true
key = key[1:]
}
log.Debugln("adding geosite policy: %s inversed %t", key, inverse)
matcher, err := NewGeoSite(key)
if err != nil {
continue
}
insertTriePolicy()
r.policy = append(r.policy, geositePolicy{
matcher: matcher,
inverse: inverse,
dnsClients: transform(nameserver, defaultResolver),
})
continue
}
}
if triePolicy == nil {
triePolicy = trie.New[[]dnsClient]()
}
_ = triePolicy.Insert(domain, transform(nameserver, defaultResolver))
} }
insertTriePolicy()
} }
fallbackIPFilters := []fallbackIPFilter{} fallbackIPFilters := []fallbackIPFilter{}
@ -508,7 +503,6 @@ func NewProxyServerHostResolver(old *Resolver) *Resolver {
main: old.proxyServer, main: old.proxyServer,
lruCache: old.lruCache, lruCache: old.lruCache,
hosts: old.hosts, hosts: old.hosts,
policy: trie.New[*Policy](),
ipv6Timeout: old.ipv6Timeout, ipv6Timeout: old.ipv6Timeout,
} }
return r return r

View file

@ -7,7 +7,6 @@ import (
"os" "os"
"runtime" "runtime"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -208,25 +207,6 @@ func updateDNS(c *config.DNS, ruleProvider map[string]provider.RuleProvider, gen
dns.ReCreateServer("", nil, nil) dns.ReCreateServer("", nil, nil)
return return
} }
policy := make(map[string][]dns.NameServer)
domainSetPolicies := make(map[provider.RuleProvider][]dns.NameServer)
for key, nameservers := range c.NameServerPolicy {
temp := strings.Split(key, ":")
if len(temp) == 2 {
prefix := temp[0]
key := temp[1]
switch strings.ToLower(prefix) {
case "rule-set":
if p, ok := ruleProvider[key]; ok {
domainSetPolicies[p] = nameservers
}
case "geosite":
// TODO:
}
} else {
policy[key] = nameservers
}
}
cfg := dns.Config{ cfg := dns.Config{
Main: c.NameServer, Main: c.NameServer,
Fallback: c.Fallback, Fallback: c.Fallback,
@ -242,10 +222,10 @@ func updateDNS(c *config.DNS, ruleProvider map[string]provider.RuleProvider, gen
Domain: c.FallbackFilter.Domain, Domain: c.FallbackFilter.Domain,
GeoSite: c.FallbackFilter.GeoSite, GeoSite: c.FallbackFilter.GeoSite,
}, },
Default: c.DefaultNameserver, Default: c.DefaultNameserver,
Policy: c.NameServerPolicy, Policy: c.NameServerPolicy,
ProxyServer: c.ProxyServerNameserver, ProxyServer: c.ProxyServerNameserver,
DomainSetPolicy: domainSetPolicies, RuleProviders: ruleProvider,
} }
r := dns.NewResolver(cfg) r := dns.NewResolver(cfg)