chore: dual stack fallback

This commit is contained in:
Skyxim 2023-02-26 10:42:22 +08:00
parent cdd91f5132
commit 8df44a7c3d

View file

@ -8,6 +8,7 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
"time"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
) )
@ -19,6 +20,7 @@ var (
tcpConcurrent = false tcpConcurrent = false
ErrorInvalidedNetworkStack = errors.New("invalided network stack") ErrorInvalidedNetworkStack = errors.New("invalided network stack")
ErrorConnTimeout = errors.New("connect timeout") ErrorConnTimeout = errors.New("connect timeout")
fallbackTimeout = 300 * time.Millisecond
) )
func applyOptions(options ...Option) *option { func applyOptions(options ...Option) *option {
@ -131,9 +133,10 @@ func serialDualStackDialContext(ctx context.Context, network, address string, op
if err != nil { if err != nil {
return nil, err return nil, err
} }
return dualStackDial( return dualStackDialContext(
func() (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, ctx,
func() (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) },
func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) },
opt.prefer == 4) opt.prefer == 4)
} }
@ -159,9 +162,14 @@ func concurrentDualStackDialContext(ctx context.Context, network, address string
return parallelDialContext(ctx, network, ips, port, opt) return parallelDialContext(ctx, network, ips, port, opt)
} }
ipv4s, ipv6s := sortationAddr(ips) ipv4s, ipv6s := sortationAddr(ips)
return dualStackDial( return dualStackDialContext(
func() (net.Conn, error) { return parallelDialContext(ctx, network, ipv4s, port, opt) }, ctx,
func() (net.Conn, error) { return parallelDialContext(ctx, network, ipv6s, port, opt) }, func(ctx context.Context) (net.Conn, error) {
return parallelDialContext(ctx, network, ipv4s, port, opt)
},
func(ctx context.Context) (net.Conn, error) {
return parallelDialContext(ctx, network, ipv6s, port, opt)
},
opt.prefer == 4) opt.prefer == 4)
} }
@ -182,14 +190,19 @@ func NewDialer(options ...Option) Dialer {
return Dialer{Opt: *opt} return Dialer{Opt: *opt}
} }
func dualStackDial( func dualStackDialContext(
ipv4DialFn func() (net.Conn, error), ctx context.Context,
ipv6DialFn func() (net.Conn, error), ipv4DialFn func(ctx context.Context) (net.Conn, error),
ipv6DialFn func(ctx context.Context) (net.Conn, error),
preferIPv4 bool) (net.Conn, error) { preferIPv4 bool) (net.Conn, error) {
fallbackTimer := time.NewTimer(fallbackTimeout)
defer fallbackTimer.Stop()
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
results := make(chan dialResult) results := make(chan dialResult)
returned := make(chan struct{}) returned := make(chan struct{})
defer close(returned) defer close(returned)
racer := func(dial func() (net.Conn, error), isPrimary bool) { racer := func(dial func(ctx context.Context) (net.Conn, error), isPrimary bool) {
result := dialResult{isPrimary: isPrimary} result := dialResult{isPrimary: isPrimary}
defer func() { defer func() {
select { select {
@ -200,29 +213,33 @@ func dualStackDial(
} }
} }
}() }()
result.Conn, result.error = dial() result.Conn, result.error = dial(fallbackCtx)
} }
go racer(ipv4DialFn, preferIPv4) go racer(ipv4DialFn, preferIPv4)
go racer(ipv6DialFn, !preferIPv4) go racer(ipv6DialFn, !preferIPv4)
var fallbackErr dialResult var fallback dialResult
var primaryErr dialResult var err error
for res := range results { for {
if res.error == nil { select {
if res.isPrimary { case <-ctx.Done():
return res.Conn, nil if fallback.error == nil && fallback.Conn != nil {
return fallback.Conn, nil
} }
fallbackErr = res return nil, fmt.Errorf("dual stack connect failed: %w", err)
} case <-fallbackTimer.C:
if res.isPrimary { if fallback.error == nil && fallback.Conn != nil {
primaryErr = res return fallback.Conn, nil
} else { }
fallbackErr = res case res := <-results:
if res.error == nil {
if res.isPrimary {
return res.Conn, nil
}
fallback = res
}
err = res.error
} }
} }
if primaryErr.error != nil {
return nil, primaryErr.error
}
return nil, fallbackErr.error
} }
func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {