diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index 19875b8c..7c021c87 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -12,21 +12,18 @@ import ( "strconv" "strings" "sync" - "time" - "github.com/metacubex/mihomo/common/atomic" CN "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/resolver" + "github.com/metacubex/mihomo/component/slowdown" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/dns" "github.com/metacubex/mihomo/log" wireguard "github.com/metacubex/sing-wireguard" - "github.com/jpillora/backoff" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" @@ -129,48 +126,6 @@ func (option WireGuardPeerOption) Prefixes() ([]netip.Prefix, error) { return localPrefixes, nil } -type wgSingDialer struct { - proxydialer.SingDialer - errTimes atomic.Int64 - backoff *backoff.Backoff -} - -func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - if d.errTimes.Load() > 10 { - select { - case <-time.After(d.backoff.Duration()): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - c, err := d.SingDialer.DialContext(ctx, network, destination) - if err != nil { - d.errTimes.Add(1) - return nil, err - } - d.errTimes.Store(0) - d.backoff.Reset() - return c, nil -} - -func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - if d.errTimes.Load() > 10 { - select { - case <-time.After(d.backoff.Duration()): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - c, err := d.SingDialer.ListenPacket(ctx, destination) - if err != nil { - d.errTimes.Add(1) - return nil, err - } - d.errTimes.Store(0) - d.backoff.Reset() - return c, nil -} - func NewWireGuard(option WireGuardOption) (*WireGuard, error) { outbound := &WireGuard{ Base: &Base{ @@ -182,16 +137,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { rmark: option.RoutingMark, prefer: C.NewDNSPrefer(option.IPVersion), }, - dialer: &wgSingDialer{ - SingDialer: proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), - errTimes: atomic.NewInt64(0), - backoff: &backoff.Backoff{ - Min: 10 * time.Millisecond, - Max: 1 * time.Second, - Factor: 2, - Jitter: true, - }, - }, + dialer: proxydialer.NewSlowDownSingDialer(proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), slowdown.New()), } runtime.SetFinalizer(outbound, closeWireGuard) diff --git a/component/proxydialer/slowdown.go b/component/proxydialer/slowdown.go new file mode 100644 index 00000000..c62fc344 --- /dev/null +++ b/component/proxydialer/slowdown.go @@ -0,0 +1,34 @@ +package proxydialer + +import ( + "context" + "net" + "net/netip" + + "github.com/metacubex/mihomo/component/slowdown" + C "github.com/metacubex/mihomo/constant" +) + +type SlowDownDialer struct { + C.Dialer + Slowdown *slowdown.SlowDown +} + +func (d SlowDownDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return slowdown.Do(d.Slowdown, ctx, func() (net.Conn, error) { + return d.Dialer.DialContext(ctx, network, address) + }) +} + +func (d SlowDownDialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { + return slowdown.Do(d.Slowdown, ctx, func() (net.PacketConn, error) { + return d.Dialer.ListenPacket(ctx, network, address, rAddrPort) + }) +} + +func NewSlowDownDialer(d C.Dialer, sd *slowdown.SlowDown) SlowDownDialer { + return SlowDownDialer{ + Dialer: d, + Slowdown: sd, + } +} diff --git a/component/proxydialer/slowdown_sing.go b/component/proxydialer/slowdown_sing.go new file mode 100644 index 00000000..cc3a46aa --- /dev/null +++ b/component/proxydialer/slowdown_sing.go @@ -0,0 +1,33 @@ +package proxydialer + +import ( + "context" + "net" + + "github.com/metacubex/mihomo/component/slowdown" + M "github.com/sagernet/sing/common/metadata" +) + +type SlowDownSingDialer struct { + SingDialer + Slowdown *slowdown.SlowDown +} + +func (d SlowDownSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + return slowdown.Do(d.Slowdown, ctx, func() (net.Conn, error) { + return d.SingDialer.DialContext(ctx, network, destination) + }) +} + +func (d SlowDownSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return slowdown.Do(d.Slowdown, ctx, func() (net.PacketConn, error) { + return d.SingDialer.ListenPacket(ctx, destination) + }) +} + +func NewSlowDownSingDialer(d SingDialer, sd *slowdown.SlowDown) SlowDownSingDialer { + return SlowDownSingDialer{ + SingDialer: d, + Slowdown: sd, + } +} diff --git a/component/slowdown/backoff.go b/component/slowdown/backoff.go new file mode 100644 index 00000000..a5a7251c --- /dev/null +++ b/component/slowdown/backoff.go @@ -0,0 +1,101 @@ +// modify from https://github.com/jpillora/backoff/blob/v1.0.0/backoff.go + +package slowdown + +import ( + "math" + "math/rand" + "sync/atomic" + "time" +) + +// Backoff is a time.Duration counter, starting at Min. After every call to +// the Duration method the current timing is multiplied by Factor, but it +// never exceeds Max. +// +// Backoff is not generally concurrent-safe, but the ForAttempt method can +// be used concurrently. +type Backoff struct { + attempt atomic.Uint64 + // Factor is the multiplying factor for each increment step + Factor float64 + // Jitter eases contention by randomizing backoff steps + Jitter bool + // Min and Max are the minimum and maximum values of the counter + Min, Max time.Duration +} + +// Duration returns the duration for the current attempt before incrementing +// the attempt counter. See ForAttempt. +func (b *Backoff) Duration() time.Duration { + d := b.ForAttempt(float64(b.attempt.Add(1) - 1)) + return d +} + +const maxInt64 = float64(math.MaxInt64 - 512) + +// ForAttempt returns the duration for a specific attempt. This is useful if +// you have a large number of independent Backoffs, but don't want use +// unnecessary memory storing the Backoff parameters per Backoff. The first +// attempt should be 0. +// +// ForAttempt is concurrent-safe. +func (b *Backoff) ForAttempt(attempt float64) time.Duration { + // Zero-values are nonsensical, so we use + // them to apply defaults + min := b.Min + if min <= 0 { + min = 100 * time.Millisecond + } + max := b.Max + if max <= 0 { + max = 10 * time.Second + } + if min >= max { + // short-circuit + return max + } + factor := b.Factor + if factor <= 0 { + factor = 2 + } + //calculate this duration + minf := float64(min) + durf := minf * math.Pow(factor, attempt) + if b.Jitter { + durf = rand.Float64()*(durf-minf) + minf + } + //ensure float64 wont overflow int64 + if durf > maxInt64 { + return max + } + dur := time.Duration(durf) + //keep within bounds + if dur < min { + return min + } + if dur > max { + return max + } + return dur +} + +// Reset restarts the current attempt counter at zero. +func (b *Backoff) Reset() { + b.attempt.Store(0) +} + +// Attempt returns the current attempt counter value. +func (b *Backoff) Attempt() float64 { + return float64(b.attempt.Load()) +} + +// Copy returns a backoff with equals constraints as the original +func (b *Backoff) Copy() *Backoff { + return &Backoff{ + Factor: b.Factor, + Jitter: b.Jitter, + Min: b.Min, + Max: b.Max, + } +} diff --git a/component/slowdown/slowdown.go b/component/slowdown/slowdown.go new file mode 100644 index 00000000..3fc12191 --- /dev/null +++ b/component/slowdown/slowdown.go @@ -0,0 +1,49 @@ +package slowdown + +import ( + "context" + "sync/atomic" + "time" +) + +type SlowDown struct { + errTimes atomic.Int64 + backoff Backoff +} + +func (s *SlowDown) Wait(ctx context.Context) (err error) { + select { + case <-time.After(s.backoff.Duration()): + case <-ctx.Done(): + err = ctx.Err() + } + return +} + +func New() *SlowDown { + return &SlowDown{ + backoff: Backoff{ + Min: 10 * time.Millisecond, + Max: 1 * time.Second, + Factor: 2, + Jitter: true, + }, + } +} + +func Do[T any](s *SlowDown, ctx context.Context, fn func() (T, error)) (t T, err error) { + if s.errTimes.Load() > 10 { + err = s.Wait(ctx) + if err != nil { + return + } + } + t, err = fn() + if err != nil { + s.errTimes.Add(1) + return + } + s.errTimes.Store(0) + s.backoff.Reset() + return +} diff --git a/go.mod b/go.mod index c2571ca2..54a0a5dc 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/gobwas/ws v1.3.2 github.com/gofrs/uuid/v5 v5.0.0 github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 - github.com/jpillora/backoff v1.0.0 github.com/klauspost/cpuid/v2 v2.2.6 github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 github.com/mdlayher/netlink v1.7.2 diff --git a/go.sum b/go.sum index 8b821f53..a7f6cd02 100644 --- a/go.sum +++ b/go.sum @@ -84,8 +84,6 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= -github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= -github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index cba36d9c..f8fdcf11 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -11,12 +11,11 @@ import ( "sync" "time" - "github.com/jpillora/backoff" - N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/component/nat" P "github.com/metacubex/mihomo/component/process" "github.com/metacubex/mihomo/component/resolver" + "github.com/metacubex/mihomo/component/slowdown" "github.com/metacubex/mihomo/component/sniffer" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/constant/features" @@ -699,12 +698,7 @@ func shouldStopRetry(err error) bool { } func retry[T any](ctx context.Context, ft func(context.Context) (T, error), fe func(err error)) (t T, err error) { - b := &backoff.Backoff{ - Min: 10 * time.Millisecond, - Max: 1 * time.Second, - Factor: 2, - Jitter: true, - } + s := slowdown.New() for i := 0; i < 10; i++ { t, err = ft(ctx) if err != nil { @@ -714,10 +708,9 @@ func retry[T any](ctx context.Context, ft func(context.Context) (T, error), fe f if shouldStopRetry(err) { return } - select { - case <-time.After(b.Duration()): + if s.Wait(ctx) == nil { continue - case <-ctx.Done(): + } else { return } } else {