From 9b89ff9f2dca997668a28d35bb8596df120a1dd1 Mon Sep 17 00:00:00 2001 From: adlyq Date: Tue, 6 Sep 2022 17:30:35 +0800 Subject: [PATCH] feat: support sub-rule, eg. rules: - SUB-RULE,(AND,((NETWORK,TCP),(DOMAIN-KEYWORD,google))),TEST2 - SUB-RULE,(GEOIP,!CN),TEST1 - MATCH,DIRECT sub-rules: TEST2: - MATCH,Proxy TEST1: - RULE-SET,Local,DIRECT,no-resolve - GEOSITE,CN,Domestic - GEOIP,CN,Domestic - MATCH,Proxy --- config/config.go | 130 ++++++++++++++++++++++++--- constant/rule.go | 5 +- rules/common/base.go | 2 +- rules/common/domain.go | 8 +- rules/common/domain_keyword.go | 8 +- rules/common/domain_suffix.go | 8 +- rules/common/final.go | 6 +- rules/common/geoip.go | 12 +-- rules/common/geosite.go | 8 +- rules/common/in_type.go | 6 +- rules/common/ipcidr.go | 6 +- rules/common/ipsuffix.go | 10 +-- rules/common/network_type.go | 4 +- rules/common/port.go | 6 +- rules/common/process.go | 6 +- rules/common/uid.go | 10 +-- rules/logic/and.go | 12 +-- rules/logic/common.go | 10 +-- rules/logic/logic_test.go | 23 +++-- rules/logic/not.go | 16 +++- rules/logic/or.go | 12 +-- rules/parser.go | 7 +- rules/provider/classical_strategy.go | 8 +- rules/provider/parse.go | 2 +- rules/provider/provider.go | 4 +- rules/provider/rule_set.go | 6 +- rules/sub_rule/sub_rules.go | 91 +++++++++++++++++++ tunnel/tunnel.go | 4 +- 28 files changed, 325 insertions(+), 105 deletions(-) create mode 100644 rules/sub_rule/sub_rules.go diff --git a/config/config.go b/config/config.go index 59d493c2..7daf1549 100644 --- a/config/config.go +++ b/config/config.go @@ -153,6 +153,7 @@ type Config struct { Hosts *trie.DomainTrie[netip.Addr] Profile *Profile Rules []C.Rule + SubRules *map[string][]C.Rule Users []auth.AuthUser Proxies map[string]C.Proxy Providers map[string]providerTypes.ProxyProvider @@ -233,6 +234,7 @@ type RawConfig struct { Proxy []map[string]any `yaml:"proxies"` ProxyGroup []map[string]any `yaml:"proxy-groups"` Rule []string `yaml:"rules"` + SubRules map[string][]string `yaml:"sub-rules"` } type RawGeoXUrl struct { @@ -381,12 +383,18 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { config.Proxies = proxies config.Providers = providers - rules, ruleProviders, err := parseRules(rawCfg, proxies) + subRules, ruleProviders, err := parseSubRules(rawCfg, proxies) + if err != nil { + return nil, err + } + config.SubRules = subRules + config.RuleProviders = ruleProviders + + rules, err := parseRules(rawCfg, proxies, subRules) if err != nil { return nil, err } config.Rules = rules - config.RuleProviders = ruleProviders hosts, err := parseHosts(rawCfg) if err != nil { @@ -563,8 +571,9 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[ return proxies, providersMap, nil } -func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[string]providerTypes.RuleProvider, error) { - ruleProviders := map[string]providerTypes.RuleProvider{} +func parseSubRules(cfg *RawConfig, proxies map[string]C.Proxy) (subRules *map[string][]C.Rule, ruleProviders map[string]providerTypes.RuleProvider, err error) { + ruleProviders = map[string]providerTypes.RuleProvider{} + subRules = &map[string][]C.Rule{} log.Infoln("Geodata Loader mode: %s", geodata.LoaderName()) // parse rule provider for name, mapping := range cfg.RuleProvider { @@ -577,6 +586,102 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[strin RP.SetRuleProvider(rp) } + for name, rawRules := range cfg.SubRules { + var rules []C.Rule + for idx, line := range rawRules { + rawRule := trimArr(strings.Split(line, ",")) + var ( + payload string + target string + params []string + ruleName = strings.ToUpper(rawRule[0]) + ) + + l := len(rawRule) + + if ruleName == "NOT" || ruleName == "OR" || ruleName == "AND" || ruleName == "SUB-RULE" { + target = rawRule[l-1] + payload = strings.Join(rawRule[1:l-1], ",") + } else { + if l < 2 { + return nil, nil, fmt.Errorf("sub-rules[%d] [%s] error: format invalid", idx, line) + } + if l < 4 { + rawRule = append(rawRule, make([]string, 4-l)...) + } + if ruleName == "MATCH" { + l = 2 + } + if l >= 3 { + l = 3 + payload = rawRule[1] + } + target = rawRule[l-1] + params = rawRule[l:] + } + + if _, ok := proxies[target]; !ok && ruleName != "SUB-RULE" { + return nil, nil, fmt.Errorf("sub-rules[%d:%s] [%s] error: proxy [%s] not found", idx, name, line, target) + } + + params = trimArr(params) + parsed, parseErr := R.ParseRule(ruleName, payload, target, params, subRules) + if parseErr != nil { + return nil, nil, fmt.Errorf("sub-rules[%d] [%s] error: %s", idx, line, parseErr.Error()) + } + + rules = append(rules, parsed) + } + (*subRules)[name] = rules + } + + if err = verifySubRule(subRules); err != nil { + return nil, nil, err + } + + return +} + +func verifySubRule(subRules *map[string][]C.Rule) error { + for name := range *subRules { + err := verifySubRuleCircularReferences(name, subRules, []string{}) + if err != nil { + return err + } + } + return nil +} + +func verifySubRuleCircularReferences(n string, subRules *map[string][]C.Rule, arr []string) error { + isInArray := func(v string, array []string) bool { + for _, c := range array { + if v == c { + return true + } + } + return false + } + + arr = append(arr, n) + for i, rule := range (*subRules)[n] { + if rule.RuleType() == C.SubRules { + if _, ok := (*subRules)[rule.Adapter()]; !ok { + return fmt.Errorf("sub-rule[%d:%s] error: [%s] not found", i, n, rule.Adapter()) + } + if isInArray(rule.Adapter(), arr) { + arr = append(arr, rule.Adapter()) + return fmt.Errorf("sub-rule error: circular references [%s]", strings.Join(arr, "->")) + } + + if err := verifySubRuleCircularReferences(rule.Adapter(), subRules, arr); err != nil { + return err + } + } + } + return nil +} + +func parseRules(cfg *RawConfig, proxies map[string]C.Proxy, subRules *map[string][]C.Rule) ([]C.Rule, error) { var rules []C.Rule rulesConfig := cfg.Rule @@ -592,12 +697,12 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[strin l := len(rule) - if ruleName == "NOT" || ruleName == "OR" || ruleName == "AND" { + if ruleName == "NOT" || ruleName == "OR" || ruleName == "AND" || ruleName == "SUB-RULE" { target = rule[l-1] payload = strings.Join(rule[1:l-1], ",") } else { if l < 2 { - return nil, nil, fmt.Errorf("rules[%d] [%s] error: format invalid", idx, line) + return nil, fmt.Errorf("rules[%d] [%s] error: format invalid", idx, line) } if l < 4 { rule = append(rule, make([]string, 4-l)...) @@ -612,15 +717,18 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[strin target = rule[l-1] params = rule[l:] } - if _, ok := proxies[target]; !ok { - return nil, nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target) + if ruleName != "SUB-RULE" { + return nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target) + } else if _, ok = (*subRules)[target]; !ok { + return nil, fmt.Errorf("rules[%d] [%s] error: sub-rule [%s] not found", idx, line, target) + } } params = trimArr(params) - parsed, parseErr := R.ParseRule(ruleName, payload, target, params) + parsed, parseErr := R.ParseRule(ruleName, payload, target, params, subRules) if parseErr != nil { - return nil, nil, fmt.Errorf("rules[%d] [%s] error: %s", idx, line, parseErr.Error()) + return nil, fmt.Errorf("rules[%d] [%s] error: %s", idx, line, parseErr.Error()) } rules = append(rules, parsed) @@ -628,7 +736,7 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[strin runtime.GC() - return rules, ruleProviders, nil + return rules, nil } func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) { diff --git a/constant/rule.go b/constant/rule.go index a403ac63..223dec6e 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -19,6 +19,7 @@ const ( Network Uid INTYPE + SubRules MATCH AND OR @@ -65,6 +66,8 @@ func (rt RuleType) String() string { return "Uid" case INTYPE: return "InType" + case SubRules: + return "SubRules" case AND: return "AND" case OR: @@ -78,7 +81,7 @@ func (rt RuleType) String() string { type Rule interface { RuleType() RuleType - Match(metadata *Metadata) bool + Match(metadata *Metadata) (bool, string) Adapter() string Payload() string ShouldResolveIP() bool diff --git a/rules/common/base.go b/rules/common/base.go index 42b4d770..abc2b434 100644 --- a/rules/common/base.go +++ b/rules/common/base.go @@ -5,7 +5,7 @@ import ( ) var ( - errPayload = errors.New("payload error") + errPayload = errors.New("payloadRule error") initFlag bool noResolve = "no-resolve" ) diff --git a/rules/common/domain.go b/rules/common/domain.go index 4a8a7d27..2f0e4f2f 100644 --- a/rules/common/domain.go +++ b/rules/common/domain.go @@ -18,11 +18,11 @@ func (d *Domain) RuleType() C.RuleType { return C.Domain } -func (d *Domain) Match(metadata *C.Metadata) bool { +func (d *Domain) Match(metadata *C.Metadata) (bool, string) { if metadata.AddrType != C.AtypDomainName { - return false + return false, "" } - return metadata.Host == d.domain + return metadata.Host == d.domain, d.adapter } func (d *Domain) Adapter() string { @@ -47,4 +47,4 @@ func NewDomain(domain string, adapter string) *Domain { } } -var _ C.Rule = (*Domain)(nil) +//var _ C.Rule = (*Domain)(nil) diff --git a/rules/common/domain_keyword.go b/rules/common/domain_keyword.go index 667b2861..58257544 100644 --- a/rules/common/domain_keyword.go +++ b/rules/common/domain_keyword.go @@ -18,12 +18,12 @@ func (dk *DomainKeyword) RuleType() C.RuleType { return C.DomainKeyword } -func (dk *DomainKeyword) Match(metadata *C.Metadata) bool { +func (dk *DomainKeyword) Match(metadata *C.Metadata) (bool, string) { if metadata.AddrType != C.AtypDomainName { - return false + return false, "" } domain := metadata.Host - return strings.Contains(domain, dk.keyword) + return strings.Contains(domain, dk.keyword), dk.adapter } func (dk *DomainKeyword) Adapter() string { @@ -48,4 +48,4 @@ func NewDomainKeyword(keyword string, adapter string) *DomainKeyword { } } -var _ C.Rule = (*DomainKeyword)(nil) +//var _ C.Rule = (*DomainKeyword)(nil) diff --git a/rules/common/domain_suffix.go b/rules/common/domain_suffix.go index c2edcd16..c9016744 100644 --- a/rules/common/domain_suffix.go +++ b/rules/common/domain_suffix.go @@ -18,12 +18,12 @@ func (ds *DomainSuffix) RuleType() C.RuleType { return C.DomainSuffix } -func (ds *DomainSuffix) Match(metadata *C.Metadata) bool { +func (ds *DomainSuffix) Match(metadata *C.Metadata) (bool, string) { if metadata.AddrType != C.AtypDomainName { - return false + return false, "" } domain := metadata.Host - return strings.HasSuffix(domain, "."+ds.suffix) || domain == ds.suffix + return strings.HasSuffix(domain, "."+ds.suffix) || domain == ds.suffix, ds.adapter } func (ds *DomainSuffix) Adapter() string { @@ -48,4 +48,4 @@ func NewDomainSuffix(suffix string, adapter string) *DomainSuffix { } } -var _ C.Rule = (*DomainSuffix)(nil) +//var _ C.Rule = (*DomainSuffix)(nil) diff --git a/rules/common/final.go b/rules/common/final.go index e42baf92..8aa5ed7b 100644 --- a/rules/common/final.go +++ b/rules/common/final.go @@ -13,8 +13,8 @@ func (f *Match) RuleType() C.RuleType { return C.MATCH } -func (f *Match) Match(metadata *C.Metadata) bool { - return true +func (f *Match) Match(metadata *C.Metadata) (bool, string) { + return true, f.adapter } func (f *Match) Adapter() string { @@ -32,4 +32,4 @@ func NewMatch(adapter string) *Match { } } -var _ C.Rule = (*Match)(nil) +//var _ C.Rule = (*Match)(nil) diff --git a/rules/common/geoip.go b/rules/common/geoip.go index 0a7670e1..ada862d2 100644 --- a/rules/common/geoip.go +++ b/rules/common/geoip.go @@ -25,10 +25,10 @@ func (g *GEOIP) RuleType() C.RuleType { return C.GEOIP } -func (g *GEOIP) Match(metadata *C.Metadata) bool { +func (g *GEOIP) Match(metadata *C.Metadata) (bool, string) { ip := metadata.DstIP if !ip.IsValid() { - return false + return false, "" } if strings.EqualFold(g.country, "LAN") { @@ -37,13 +37,13 @@ func (g *GEOIP) Match(metadata *C.Metadata) bool { ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || - resolver.IsFakeBroadcastIP(ip) + resolver.IsFakeBroadcastIP(ip), g.adapter } if !C.GeodataMode { record, _ := mmdb.Instance().Country(ip.AsSlice()) - return strings.EqualFold(record.Country.IsoCode, g.country) + return strings.EqualFold(record.Country.IsoCode, g.country), g.adapter } - return g.geoIPMatcher.Match(ip.AsSlice()) + return g.geoIPMatcher.Match(ip.AsSlice()), g.adapter } func (g *GEOIP) Adapter() string { @@ -98,4 +98,4 @@ func NewGEOIP(country string, adapter string, noResolveIP bool) (*GEOIP, error) return geoip, nil } -var _ C.Rule = (*GEOIP)(nil) +//var _ C.Rule = (*GEOIP)(nil) diff --git a/rules/common/geosite.go b/rules/common/geosite.go index c88fd78d..9897f349 100644 --- a/rules/common/geosite.go +++ b/rules/common/geosite.go @@ -23,13 +23,13 @@ func (gs *GEOSITE) RuleType() C.RuleType { return C.GEOSITE } -func (gs *GEOSITE) Match(metadata *C.Metadata) bool { +func (gs *GEOSITE) Match(metadata *C.Metadata) (bool, string) { if metadata.AddrType != C.AtypDomainName { - return false + return false, "" } domain := metadata.Host - return gs.matcher.ApplyDomain(domain) + return gs.matcher.ApplyDomain(domain), gs.adapter } func (gs *GEOSITE) Adapter() string { @@ -75,4 +75,4 @@ func NewGEOSITE(country string, adapter string) (*GEOSITE, error) { return geoSite, nil } -var _ C.Rule = (*GEOSITE)(nil) +//var _ C.Rule = (*GEOSITE)(nil) diff --git a/rules/common/in_type.go b/rules/common/in_type.go index c577c843..520c9594 100644 --- a/rules/common/in_type.go +++ b/rules/common/in_type.go @@ -13,13 +13,13 @@ type InType struct { payload string } -func (u *InType) Match(metadata *C.Metadata) bool { +func (u *InType) Match(metadata *C.Metadata) (bool, string) { for _, tp := range u.types { if metadata.Type == tp { - return true + return true, u.adapter } } - return false + return false, "" } func (u *InType) RuleType() C.RuleType { diff --git a/rules/common/ipcidr.go b/rules/common/ipcidr.go index 5ac17cf4..8ab6cf5a 100644 --- a/rules/common/ipcidr.go +++ b/rules/common/ipcidr.go @@ -35,12 +35,12 @@ func (i *IPCIDR) RuleType() C.RuleType { return C.IPCIDR } -func (i *IPCIDR) Match(metadata *C.Metadata) bool { +func (i *IPCIDR) Match(metadata *C.Metadata) (bool, string) { ip := metadata.DstIP if i.isSourceIP { ip = metadata.SrcIP } - return ip.IsValid() && i.ipnet.Contains(ip) + return ip.IsValid() && i.ipnet.Contains(ip), i.adapter } func (i *IPCIDR) Adapter() string { @@ -74,4 +74,4 @@ func NewIPCIDR(s string, adapter string, opts ...IPCIDROption) (*IPCIDR, error) return ipcidr, nil } -var _ C.Rule = (*IPCIDR)(nil) +//var _ C.Rule = (*IPCIDR)(nil) diff --git a/rules/common/ipsuffix.go b/rules/common/ipsuffix.go index 18271244..b01557dc 100644 --- a/rules/common/ipsuffix.go +++ b/rules/common/ipsuffix.go @@ -22,7 +22,7 @@ func (is *IPSuffix) RuleType() C.RuleType { return C.IPSuffix } -func (is *IPSuffix) Match(metadata *C.Metadata) bool { +func (is *IPSuffix) Match(metadata *C.Metadata) (bool, string) { ip := metadata.DstIP if is.isSourceIP { ip = metadata.SrcIP @@ -30,7 +30,7 @@ func (is *IPSuffix) Match(metadata *C.Metadata) bool { mIPBytes := ip.AsSlice() if len(is.ipBytes) != len(mIPBytes) { - return false + return false, "" } size := len(mIPBytes) @@ -38,15 +38,15 @@ func (is *IPSuffix) Match(metadata *C.Metadata) bool { for i := bits / 8; i > 0; i-- { if is.ipBytes[size-i] != mIPBytes[size-i] { - return false + return false, "" } } if (is.ipBytes[size-bits/8-1] << (8 - bits%8)) != (mIPBytes[size-bits/8-1] << (8 - bits%8)) { - return false + return false, "" } - return true + return true, is.adapter } func (is *IPSuffix) Adapter() string { diff --git a/rules/common/network_type.go b/rules/common/network_type.go index 107df91e..fb6b5077 100644 --- a/rules/common/network_type.go +++ b/rules/common/network_type.go @@ -36,8 +36,8 @@ func (n *NetworkType) RuleType() C.RuleType { return C.Network } -func (n *NetworkType) Match(metadata *C.Metadata) bool { - return n.network == metadata.NetWork +func (n *NetworkType) Match(metadata *C.Metadata) (bool, string) { + return n.network == metadata.NetWork, n.adapter } func (n *NetworkType) Adapter() string { diff --git a/rules/common/port.go b/rules/common/port.go index 06fde6c2..270e5b20 100644 --- a/rules/common/port.go +++ b/rules/common/port.go @@ -24,11 +24,11 @@ func (p *Port) RuleType() C.RuleType { return C.DstPort } -func (p *Port) Match(metadata *C.Metadata) bool { +func (p *Port) Match(metadata *C.Metadata) (bool, string) { if p.isSource { - return p.matchPortReal(metadata.SrcPort) + return p.matchPortReal(metadata.SrcPort), p.adapter } - return p.matchPortReal(metadata.DstPort) + return p.matchPortReal(metadata.DstPort), p.adapter } func (p *Port) Adapter() string { diff --git a/rules/common/process.go b/rules/common/process.go index 69d83632..9263c32d 100644 --- a/rules/common/process.go +++ b/rules/common/process.go @@ -17,12 +17,12 @@ func (ps *Process) RuleType() C.RuleType { return C.Process } -func (ps *Process) Match(metadata *C.Metadata) bool { +func (ps *Process) Match(metadata *C.Metadata) (bool, string) { if ps.nameOnly { - return strings.EqualFold(metadata.Process, ps.process) + return strings.EqualFold(metadata.Process, ps.process), ps.adapter } - return strings.EqualFold(metadata.ProcessPath, ps.process) + return strings.EqualFold(metadata.ProcessPath, ps.process), ps.adapter } func (ps *Process) Adapter() string { diff --git a/rules/common/uid.go b/rules/common/uid.go index f7ea4875..5a989f67 100644 --- a/rules/common/uid.go +++ b/rules/common/uid.go @@ -71,10 +71,10 @@ func (u *Uid) RuleType() C.RuleType { return C.Uid } -func (u *Uid) Match(metadata *C.Metadata) bool { +func (u *Uid) Match(metadata *C.Metadata) (bool, string) { srcPort, err := strconv.ParseUint(metadata.SrcPort, 10, 16) if err != nil { - return false + return false, "" } var uid int32 if metadata.Uid != nil { @@ -83,15 +83,15 @@ func (u *Uid) Match(metadata *C.Metadata) bool { metadata.Uid = &uid } else { log.Warnln("[UID] could not get uid from %s", metadata.String()) - return false + return false, "" } for _, _uid := range u.uids { if _uid.Contains(uid) { - return true + return true, u.adapter } } - return false + return false, "" } func (u *Uid) Adapter() string { diff --git a/rules/logic/and.go b/rules/logic/and.go index 5a9b4d0f..a8fc1bad 100644 --- a/rules/logic/and.go +++ b/rules/logic/and.go @@ -20,9 +20,9 @@ func (A *AND) ShouldFindProcess() bool { } func NewAND(payload string, adapter string, - parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) (*AND, error) { + parse func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) (*AND, error) { and := &AND{Base: &common.Base{}, payload: payload, adapter: adapter} - rules, err := parseRuleByPayload(payload, parse) + rules, err := ParseRuleByPayload(payload, parse) if err != nil { return nil, err } @@ -45,14 +45,14 @@ func (A *AND) RuleType() C.RuleType { return C.AND } -func (A *AND) Match(metadata *C.Metadata) bool { +func (A *AND) Match(metadata *C.Metadata) (bool, string) { for _, rule := range A.rules { - if !rule.Match(metadata) { - return false + if m, _ := rule.Match(metadata); !m { + return false, "" } } - return true + return true, A.adapter } func (A *AND) Adapter() string { diff --git a/rules/logic/common.go b/rules/logic/common.go index 736ead43..080771ba 100644 --- a/rules/logic/common.go +++ b/rules/logic/common.go @@ -9,7 +9,7 @@ import ( _ "unsafe" ) -func parseRuleByPayload(payload string, parseRule func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) ([]C.Rule, error) { +func ParseRuleByPayload(payload string, parseRule func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) ([]C.Rule, error) { regex, err := regexp.Compile("\\(.*\\)") if err != nil { return nil, err @@ -59,13 +59,13 @@ func payloadToRule(subPayload string, parseRule func(tp, payload, target string, return parseRule(tp, param[0], "", param[1:]) } -func parseLogicSubRule(parseRule func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error) { +func parseLogicSubRule(parseRule func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error) { return func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error) { switch tp { - case "MATCH": - return nil, fmt.Errorf("unsupported rule type on logic rule") + case "MATCH", "SUB-RULE": + return nil, fmt.Errorf("unsupported rule type [%s] on logic rule", tp) default: - return parseRule(tp, payload, target, params) + return parseRule(tp, payload, target, params, nil) } } } diff --git a/rules/logic/logic_test.go b/rules/logic/logic_test.go index b7ea9ebe..dcc92d02 100644 --- a/rules/logic/logic_test.go +++ b/rules/logic/logic_test.go @@ -3,13 +3,15 @@ package logic import ( "fmt" "github.com/Dreamacro/clash/constant" + C "github.com/Dreamacro/clash/constant" RC "github.com/Dreamacro/clash/rules/common" RP "github.com/Dreamacro/clash/rules/provider" + "github.com/Dreamacro/clash/rules/sub_rule" "github.com/stretchr/testify/assert" "testing" ) -func ParseRule(tp, payload, target string, params []string) (parsed constant.Rule, parseErr error) { +func ParseRule(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed constant.Rule, parseErr error) { switch tp { case "DOMAIN": parsed = RC.NewDomain(payload, target) @@ -46,6 +48,8 @@ func ParseRule(tp, payload, target string, params []string) (parsed constant.Rul parsed, parseErr = RC.NewUid(payload, target) case "IN-TYPE": parsed, parseErr = RC.NewInType(payload, target) + case "SUB-RULE": + parsed, parseErr = sub_rule.NewSubRule(payload, target, subRules, ParseRule) case "AND": parsed, parseErr = NewAND(payload, target, ParseRule) case "OR": @@ -54,7 +58,7 @@ func ParseRule(tp, payload, target string, params []string) (parsed constant.Rul parsed, parseErr = NewNOT(payload, target, ParseRule) case "RULE-SET": noResolve := RC.HasNoResolve(params) - parsed, parseErr = RP.NewRuleSet(payload, target, noResolve, ParseRule) + parsed, parseErr = RP.NewRuleSet(payload, target, noResolve) case "MATCH": parsed = RC.NewMatch(target) parseErr = nil @@ -70,12 +74,13 @@ func TestAND(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, "DIRECT", and.adapter) assert.Equal(t, false, and.ShouldResolveIP()) - assert.Equal(t, true, and.Match(&constant.Metadata{ + m, _ := and.Match(&constant.Metadata{ Host: "baidu.com", AddrType: constant.AtypDomainName, NetWork: constant.TCP, DstPort: "20000", - })) + }) + assert.Equal(t, true, m) and, err = NewAND("(DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT", ParseRule) assert.NotEqual(t, nil, err) @@ -87,9 +92,10 @@ func TestAND(t *testing.T) { func TestNOT(t *testing.T) { not, err := NewNOT("((DST-PORT,6000-6500))", "REJECT", ParseRule) assert.Equal(t, nil, err) - assert.Equal(t, false, not.Match(&constant.Metadata{ + m, _ := not.Match(&constant.Metadata{ DstPort: "6100", - })) + }) + assert.Equal(t, false, m) _, err = NewNOT("((DST-PORT,5600-6666),(DOMAIN,baidu.com))", "DIRECT", ParseRule) assert.NotEqual(t, nil, err) @@ -101,8 +107,9 @@ func TestNOT(t *testing.T) { func TestOR(t *testing.T) { or, err := NewOR("((DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT", ParseRule) assert.Equal(t, nil, err) - assert.Equal(t, true, or.Match(&constant.Metadata{ + m, _ := or.Match(&constant.Metadata{ NetWork: constant.TCP, - })) + }) + assert.Equal(t, true, m) assert.Equal(t, false, or.ShouldResolveIP()) } diff --git a/rules/logic/not.go b/rules/logic/not.go index dc14e1d1..e584a615 100644 --- a/rules/logic/not.go +++ b/rules/logic/not.go @@ -17,9 +17,9 @@ func (not *NOT) ShouldFindProcess() bool { return false } -func NewNOT(payload string, adapter string, parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) (*NOT, error) { +func NewNOT(payload string, adapter string, parse func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) (*NOT, error) { not := &NOT{Base: &common.Base{}, adapter: adapter} - rule, err := parseRuleByPayload(payload, parse) + rule, err := ParseRuleByPayload(payload, parse) if err != nil { return nil, err } @@ -38,8 +38,16 @@ func (not *NOT) RuleType() C.RuleType { return C.NOT } -func (not *NOT) Match(metadata *C.Metadata) bool { - return not.rule == nil || !not.rule.Match(metadata) +func (not *NOT) Match(metadata *C.Metadata) (bool, string) { + if not.rule == nil { + return true, not.adapter + } + + if m, _ := not.rule.Match(metadata); m { + return true, not.adapter + } + + return false, "" } func (not *NOT) Adapter() string { diff --git a/rules/logic/or.go b/rules/logic/or.go index 08698b02..d1aae9ac 100644 --- a/rules/logic/or.go +++ b/rules/logic/or.go @@ -23,14 +23,14 @@ func (or *OR) RuleType() C.RuleType { return C.OR } -func (or *OR) Match(metadata *C.Metadata) bool { +func (or *OR) Match(metadata *C.Metadata) (bool, string) { for _, rule := range or.rules { - if rule.Match(metadata) { - return true + if m, _ := rule.Match(metadata); m { + return true, or.adapter } } - return false + return false, "" } func (or *OR) Adapter() string { @@ -45,9 +45,9 @@ func (or *OR) ShouldResolveIP() bool { return or.needIP } -func NewOR(payload string, adapter string, parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) (*OR, error) { +func NewOR(payload string, adapter string, parse func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) (*OR, error) { or := &OR{Base: &common.Base{}, payload: payload, adapter: adapter} - rules, err := parseRuleByPayload(payload, parse) + rules, err := ParseRuleByPayload(payload, parse) if err != nil { return nil, err } diff --git a/rules/parser.go b/rules/parser.go index c6ca8847..878b0633 100644 --- a/rules/parser.go +++ b/rules/parser.go @@ -6,9 +6,10 @@ import ( RC "github.com/Dreamacro/clash/rules/common" "github.com/Dreamacro/clash/rules/logic" RP "github.com/Dreamacro/clash/rules/provider" + "github.com/Dreamacro/clash/rules/sub_rule" ) -func ParseRule(tp, payload, target string, params []string) (parsed C.Rule, parseErr error) { +func ParseRule(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error) { switch tp { case "DOMAIN": parsed = RC.NewDomain(payload, target) @@ -45,6 +46,8 @@ func ParseRule(tp, payload, target string, params []string) (parsed C.Rule, pars parsed, parseErr = RC.NewUid(payload, target) case "IN-TYPE": parsed, parseErr = RC.NewInType(payload, target) + case "SUB-RULE": + parsed, parseErr = sub_rule.NewSubRule(payload, target, subRules, ParseRule) case "AND": parsed, parseErr = logic.NewAND(payload, target, ParseRule) case "OR": @@ -53,7 +56,7 @@ func ParseRule(tp, payload, target string, params []string) (parsed C.Rule, pars parsed, parseErr = logic.NewNOT(payload, target, ParseRule) case "RULE-SET": noResolve := RC.HasNoResolve(params) - parsed, parseErr = RP.NewRuleSet(payload, target, noResolve, ParseRule) + parsed, parseErr = RP.NewRuleSet(payload, target, noResolve) case "MATCH": parsed = RC.NewMatch(target) parseErr = nil diff --git a/rules/provider/classical_strategy.go b/rules/provider/classical_strategy.go index 430acead..727688fc 100644 --- a/rules/provider/classical_strategy.go +++ b/rules/provider/classical_strategy.go @@ -16,7 +16,7 @@ type classicalStrategy struct { func (c *classicalStrategy) Match(metadata *C.Metadata) bool { for _, rule := range c.rules { - if rule.Match(metadata) { + if m, _ := rule.Match(metadata); m { return true } } @@ -66,13 +66,13 @@ func ruleParse(ruleRaw string) (string, string, []string) { return "", "", nil } -func NewClassicalStrategy(parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) *classicalStrategy { +func NewClassicalStrategy(parse func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) *classicalStrategy { return &classicalStrategy{rules: []C.Rule{}, parse: func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error) { switch tp { - case "MATCH": + case "MATCH", "SUB-RULE": return nil, fmt.Errorf("unsupported rule type on rule-set") default: - return parse(tp, payload, target, params) + return parse(tp, payload, target, params, nil) } }} } diff --git a/rules/provider/parse.go b/rules/provider/parse.go index 80311af0..86e21a30 100644 --- a/rules/provider/parse.go +++ b/rules/provider/parse.go @@ -17,7 +17,7 @@ type ruleProviderSchema struct { Interval int `provider:"interval,omitempty"` } -func ParseRuleProvider(name string, mapping map[string]interface{}, parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) (P.RuleProvider, error) { +func ParseRuleProvider(name string, mapping map[string]interface{}, parse func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) (P.RuleProvider, error) { schema := &ruleProviderSchema{} decoder := structure.NewDecoder(structure.Option{TagName: "provider", WeaklyTypedInput: true}) if err := decoder.Decode(mapping, schema); err != nil { diff --git a/rules/provider/provider.go b/rules/provider/provider.go index ce96c04f..9ae125fb 100644 --- a/rules/provider/provider.go +++ b/rules/provider/provider.go @@ -103,7 +103,7 @@ func (rp *ruleSetProvider) MarshalJSON() ([]byte, error) { } func NewRuleSetProvider(name string, behavior P.RuleType, interval time.Duration, vehicle P.Vehicle, - parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) P.RuleProvider { + parse func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) P.RuleProvider { rp := &ruleSetProvider{ behavior: behavior, } @@ -126,7 +126,7 @@ func NewRuleSetProvider(name string, behavior P.RuleType, interval time.Duration return wrapper } -func newStrategy(behavior P.RuleType, parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) ruleStrategy { +func newStrategy(behavior P.RuleType, parse func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) ruleStrategy { switch behavior { case P.Domain: strategy := NewDomainStrategy() diff --git a/rules/provider/rule_set.go b/rules/provider/rule_set.go index 84aaf0fb..326e3b0d 100644 --- a/rules/provider/rule_set.go +++ b/rules/provider/rule_set.go @@ -23,8 +23,8 @@ func (rs *RuleSet) RuleType() C.RuleType { return C.RuleSet } -func (rs *RuleSet) Match(metadata *C.Metadata) bool { - return rs.getProviders().Match(metadata) +func (rs *RuleSet) Match(metadata *C.Metadata) (bool, string) { + return rs.getProviders().Match(metadata), rs.adapter } func (rs *RuleSet) Adapter() string { @@ -47,7 +47,7 @@ func (rs *RuleSet) getProviders() P.RuleProvider { return rs.ruleProvider } -func NewRuleSet(ruleProviderName string, adapter string, noResolveIP bool, parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error)) (*RuleSet, error) { +func NewRuleSet(ruleProviderName string, adapter string, noResolveIP bool) (*RuleSet, error) { rp, ok := RuleProviders()[ruleProviderName] if !ok { return nil, fmt.Errorf("rule set %s not found", ruleProviderName) diff --git a/rules/sub_rule/sub_rules.go b/rules/sub_rule/sub_rules.go new file mode 100644 index 00000000..2b452b5f --- /dev/null +++ b/rules/sub_rule/sub_rules.go @@ -0,0 +1,91 @@ +package sub_rule + +import ( + "fmt" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/rules/common" + "github.com/Dreamacro/clash/rules/logic" +) + +type SubRule struct { + *common.Base + payload string + payloadRule C.Rule + subName string + subRules *map[string][]C.Rule + shouldFindProcess *bool + shouldResolveIP *bool +} + +func NewSubRule(payload, subName string, sub *map[string][]C.Rule, + parse func(tp, payload, target string, params []string, subRules *map[string][]C.Rule) (parsed C.Rule, parseErr error)) (*SubRule, error) { + payloadRule, err := logic.ParseRuleByPayload(fmt.Sprintf("(%s)", payload), parse) + if err != nil { + return nil, err + } + if len(payloadRule) != 1 { + return nil, fmt.Errorf("Sub-Rule rule must contain one rule") + } + + return &SubRule{ + Base: &common.Base{}, + payload: payload, + payloadRule: payloadRule[0], + subName: subName, + subRules: sub, + }, nil +} + +func (r *SubRule) RuleType() C.RuleType { + return C.SubRules +} + +func (r *SubRule) Match(metadata *C.Metadata) (bool, string) { + + return match(metadata, r.subName, r.subRules) +} + +func match(metadata *C.Metadata, name string, subRules *map[string][]C.Rule) (bool, string) { + for _, rule := range (*subRules)[name] { + if m, a := rule.Match(metadata); m { + if rule.RuleType() == C.SubRules { + match(metadata, rule.Adapter(), subRules) + } else { + return m, a + } + } + } + return false, "" +} + +func (r *SubRule) ShouldResolveIP() bool { + if r.shouldResolveIP == nil { + s := false + for _, rule := range (*r.subRules)[r.subName] { + s = s || rule.ShouldResolveIP() + } + r.shouldResolveIP = &s + } + + return *r.shouldResolveIP +} + +func (r *SubRule) ShouldFindProcess() bool { + if r.shouldFindProcess == nil { + s := false + for _, rule := range (*r.subRules)[r.subName] { + s = s || rule.ShouldFindProcess() + } + r.shouldFindProcess = &s + } + + return *r.shouldFindProcess +} + +func (r *SubRule) Adapter() string { + return r.subName +} + +func (r *SubRule) Payload() string { + return r.payload +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 3c74e14d..b83da180 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -410,8 +410,8 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { } } - if rule.Match(metadata) { - adapter, ok := proxies[rule.Adapter()] + if matched, ada := rule.Match(metadata); matched { + adapter, ok := proxies[ada] if !ok { continue }