From 2334bafe68017ac421abe3d0352efe0b665898f2 Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Tue, 10 Dec 2019 15:04:22 +0800 Subject: [PATCH] Change: proxy gruop strategy improvement --- adapters/outboundgroup/common.go | 20 +++++++ adapters/outboundgroup/fallback.go | 25 ++++---- adapters/outboundgroup/loadbalance.go | 13 +++-- adapters/outboundgroup/selector.go | 13 +++-- adapters/outboundgroup/urltest.go | 84 ++++++++++++++------------- common/singledo/singledo.go | 54 +++++++++++++++++ common/singledo/singledo_test.go | 52 +++++++++++++++++ 7 files changed, 198 insertions(+), 63 deletions(-) create mode 100644 adapters/outboundgroup/common.go create mode 100644 common/singledo/singledo.go create mode 100644 common/singledo/singledo_test.go diff --git a/adapters/outboundgroup/common.go b/adapters/outboundgroup/common.go new file mode 100644 index 00000000..ce40072c --- /dev/null +++ b/adapters/outboundgroup/common.go @@ -0,0 +1,20 @@ +package outboundgroup + +import ( + "time" + + "github.com/Dreamacro/clash/adapters/provider" + C "github.com/Dreamacro/clash/constant" +) + +const ( + defaultGetProxiesDuration = time.Second * 5 +) + +func getProvidersProxies(providers []provider.ProxyProvider) []C.Proxy { + proxies := []C.Proxy{} + for _, provider := range providers { + proxies = append(proxies, provider.Proxies()...) + } + return proxies +} diff --git a/adapters/outboundgroup/fallback.go b/adapters/outboundgroup/fallback.go index 7c0525b8..2104c39e 100644 --- a/adapters/outboundgroup/fallback.go +++ b/adapters/outboundgroup/fallback.go @@ -7,11 +7,13 @@ import ( "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/provider" + "github.com/Dreamacro/clash/common/singledo" C "github.com/Dreamacro/clash/constant" ) type Fallback struct { *outbound.Base + single *singledo.Single providers []provider.ProxyProvider } @@ -56,29 +58,28 @@ func (f *Fallback) MarshalJSON() ([]byte, error) { } func (f *Fallback) proxies() []C.Proxy { - proxies := []C.Proxy{} - for _, provider := range f.providers { - proxies = append(proxies, provider.Proxies()...) - } - return proxies + elm, _, _ := f.single.Do(func() (interface{}, error) { + return getProvidersProxies(f.providers), nil + }) + + return elm.([]C.Proxy) } func (f *Fallback) findAliveProxy() C.Proxy { - for _, provider := range f.providers { - proxies := provider.Proxies() - for _, proxy := range proxies { - if proxy.Alive() { - return proxy - } + proxies := f.proxies() + for _, proxy := range proxies { + if proxy.Alive() { + return proxy } } - return f.providers[0].Proxies()[0] + return f.proxies()[0] } func NewFallback(name string, providers []provider.ProxyProvider) *Fallback { return &Fallback{ Base: outbound.NewBase(name, C.Fallback, false), + single: singledo.NewSingle(defaultGetProxiesDuration), providers: providers, } } diff --git a/adapters/outboundgroup/loadbalance.go b/adapters/outboundgroup/loadbalance.go index 74a154ce..78a942e0 100644 --- a/adapters/outboundgroup/loadbalance.go +++ b/adapters/outboundgroup/loadbalance.go @@ -8,6 +8,7 @@ import ( "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/provider" "github.com/Dreamacro/clash/common/murmur3" + "github.com/Dreamacro/clash/common/singledo" C "github.com/Dreamacro/clash/constant" "golang.org/x/net/publicsuffix" @@ -15,6 +16,7 @@ import ( type LoadBalance struct { *outbound.Base + single *singledo.Single maxRetry int providers []provider.ProxyProvider } @@ -98,11 +100,11 @@ func (lb *LoadBalance) SupportUDP() bool { } func (lb *LoadBalance) proxies() []C.Proxy { - proxies := []C.Proxy{} - for _, provider := range lb.providers { - proxies = append(proxies, provider.Proxies()...) - } - return proxies + elm, _, _ := lb.single.Do(func() (interface{}, error) { + return getProvidersProxies(lb.providers), nil + }) + + return elm.([]C.Proxy) } func (lb *LoadBalance) MarshalJSON() ([]byte, error) { @@ -119,6 +121,7 @@ func (lb *LoadBalance) MarshalJSON() ([]byte, error) { func NewLoadBalance(name string, providers []provider.ProxyProvider) *LoadBalance { return &LoadBalance{ Base: outbound.NewBase(name, C.LoadBalance, false), + single: singledo.NewSingle(defaultGetProxiesDuration), maxRetry: 3, providers: providers, } diff --git a/adapters/outboundgroup/selector.go b/adapters/outboundgroup/selector.go index fc53be71..fd9ef041 100644 --- a/adapters/outboundgroup/selector.go +++ b/adapters/outboundgroup/selector.go @@ -8,11 +8,13 @@ import ( "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/provider" + "github.com/Dreamacro/clash/common/singledo" C "github.com/Dreamacro/clash/constant" ) type Selector struct { *outbound.Base + single *singledo.Single selected C.Proxy providers []provider.ProxyProvider } @@ -66,17 +68,18 @@ func (s *Selector) Set(name string) error { } func (s *Selector) proxies() []C.Proxy { - proxies := []C.Proxy{} - for _, provider := range s.providers { - proxies = append(proxies, provider.Proxies()...) - } - return proxies + elm, _, _ := s.single.Do(func() (interface{}, error) { + return getProvidersProxies(s.providers), nil + }) + + return elm.([]C.Proxy) } func NewSelector(name string, providers []provider.ProxyProvider) *Selector { selected := providers[0].Proxies()[0] return &Selector{ Base: outbound.NewBase(name, C.Selector, false), + single: singledo.NewSingle(defaultGetProxiesDuration), providers: providers, selected: selected, } diff --git a/adapters/outboundgroup/urltest.go b/adapters/outboundgroup/urltest.go index 2e57e0f2..cf1ad138 100644 --- a/adapters/outboundgroup/urltest.go +++ b/adapters/outboundgroup/urltest.go @@ -4,36 +4,35 @@ import ( "context" "encoding/json" "net" + "time" "github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/provider" + "github.com/Dreamacro/clash/common/singledo" C "github.com/Dreamacro/clash/constant" ) type URLTest struct { *outbound.Base - fast C.Proxy - providers []provider.ProxyProvider + single *singledo.Single + fastSingle *singledo.Single + providers []provider.ProxyProvider } func (u *URLTest) Now() string { - return u.fast.Name() + return u.fast().Name() } func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) { - for i := 0; i < 3; i++ { - c, err = u.fast.DialContext(ctx, metadata) - if err == nil { - c.AppendToChains(u) - return - } - u.fallback() + c, err = u.fast().DialContext(ctx, metadata) + if err == nil { + c.AppendToChains(u) } - return + return c, err } func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { - pc, addr, err := u.fast.DialUDP(metadata) + pc, addr, err := u.fast().DialUDP(metadata) if err == nil { pc.AppendToChains(u) } @@ -41,15 +40,37 @@ func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) } func (u *URLTest) proxies() []C.Proxy { - proxies := []C.Proxy{} - for _, provider := range u.providers { - proxies = append(proxies, provider.Proxies()...) - } - return proxies + elm, _, _ := u.single.Do(func() (interface{}, error) { + return getProvidersProxies(u.providers), nil + }) + + return elm.([]C.Proxy) +} + +func (u *URLTest) fast() C.Proxy { + elm, _, _ := u.fastSingle.Do(func() (interface{}, error) { + proxies := u.proxies() + fast := proxies[0] + min := fast.LastDelay() + for _, proxy := range proxies[1:] { + if !proxy.Alive() { + continue + } + + delay := proxy.LastDelay() + if delay < min { + fast = proxy + min = delay + } + } + return fast, nil + }) + + return elm.(C.Proxy) } func (u *URLTest) SupportUDP() bool { - return u.fast.SupportUDP() + return u.fast().SupportUDP() } func (u *URLTest) MarshalJSON() ([]byte, error) { @@ -64,30 +85,11 @@ func (u *URLTest) MarshalJSON() ([]byte, error) { }) } -func (u *URLTest) fallback() { - proxies := u.proxies() - fast := proxies[0] - min := fast.LastDelay() - for _, proxy := range proxies[1:] { - if !proxy.Alive() { - continue - } - - delay := proxy.LastDelay() - if delay < min { - fast = proxy - min = delay - } - } - u.fast = fast -} - func NewURLTest(name string, providers []provider.ProxyProvider) *URLTest { - fast := providers[0].Proxies()[0] - return &URLTest{ - Base: outbound.NewBase(name, C.URLTest, false), - fast: fast, - providers: providers, + Base: outbound.NewBase(name, C.URLTest, false), + single: singledo.NewSingle(defaultGetProxiesDuration), + fastSingle: singledo.NewSingle(time.Second * 10), + providers: providers, } } diff --git a/common/singledo/singledo.go b/common/singledo/singledo.go new file mode 100644 index 00000000..4828d558 --- /dev/null +++ b/common/singledo/singledo.go @@ -0,0 +1,54 @@ +package singledo + +import ( + "sync" + "time" +) + +type call struct { + wg sync.WaitGroup + val interface{} + err error +} + +type Single struct { + mux sync.Mutex + last int64 + wait int64 + call *call + result *Result +} + +type Result struct { + Val interface{} + Err error +} + +func (s *Single) Do(fn func() (interface{}, error)) (v interface{}, err error, shared bool) { + s.mux.Lock() + now := time.Now().Unix() + if now < s.last+s.wait { + s.mux.Unlock() + return s.result.Val, s.result.Err, true + } + + if call := s.call; call != nil { + s.mux.Unlock() + call.wg.Wait() + return call.val, call.err, true + } + + call := &call{} + call.wg.Add(1) + s.call = call + s.mux.Unlock() + call.val, call.err = fn() + s.call = nil + s.result = &Result{call.val, call.err} + s.last = now + return call.val, call.err, false +} + +func NewSingle(wait time.Duration) *Single { + return &Single{wait: int64(wait)} +} diff --git a/common/singledo/singledo_test.go b/common/singledo/singledo_test.go new file mode 100644 index 00000000..d6552580 --- /dev/null +++ b/common/singledo/singledo_test.go @@ -0,0 +1,52 @@ +package singledo + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBasic(t *testing.T) { + single := NewSingle(time.Millisecond * 30) + foo := 0 + shardCount := 0 + call := func() (interface{}, error) { + foo++ + return nil, nil + } + + var wg sync.WaitGroup + const n = 10 + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + _, _, shard := single.Do(call) + if shard { + shardCount++ + } + wg.Done() + }() + } + + wg.Wait() + assert.Equal(t, 1, foo) + assert.Equal(t, 9, shardCount) +} + +func TestTimer(t *testing.T) { + single := NewSingle(time.Millisecond * 30) + foo := 0 + call := func() (interface{}, error) { + foo++ + return nil, nil + } + + single.Do(call) + time.Sleep(10 * time.Millisecond) + _, _, shard := single.Do(call) + + assert.Equal(t, 1, foo) + assert.True(t, shard) +}