From 8df44a7c3d926d9f4f5640df49aed5ad19c83c35 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sun, 26 Feb 2023 10:42:22 +0800 Subject: [PATCH] chore: dual stack fallback --- component/dialer/dialer.go | 71 +++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index e6e8edc9..1a6c9d43 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -8,6 +8,7 @@ import ( "net/netip" "strings" "sync" + "time" "github.com/Dreamacro/clash/component/resolver" ) @@ -19,6 +20,7 @@ var ( tcpConcurrent = false ErrorInvalidedNetworkStack = errors.New("invalided network stack") ErrorConnTimeout = errors.New("connect timeout") + fallbackTimeout = 300 * time.Millisecond ) func applyOptions(options ...Option) *option { @@ -131,9 +133,10 @@ func serialDualStackDialContext(ctx context.Context, network, address string, op if err != nil { return nil, err } - return dualStackDial( - func() (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, - func() (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, + return dualStackDialContext( + ctx, + func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, + func(ctx context.Context) (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, opt.prefer == 4) } @@ -159,9 +162,14 @@ func concurrentDualStackDialContext(ctx context.Context, network, address string return parallelDialContext(ctx, network, ips, port, opt) } ipv4s, ipv6s := sortationAddr(ips) - return dualStackDial( - func() (net.Conn, error) { return parallelDialContext(ctx, network, ipv4s, port, opt) }, - func() (net.Conn, error) { return parallelDialContext(ctx, network, ipv6s, port, opt) }, + return dualStackDialContext( + ctx, + func(ctx context.Context) (net.Conn, error) { + return parallelDialContext(ctx, network, ipv4s, port, opt) + }, + func(ctx context.Context) (net.Conn, error) { + return parallelDialContext(ctx, network, ipv6s, port, opt) + }, opt.prefer == 4) } @@ -182,14 +190,19 @@ func NewDialer(options ...Option) Dialer { return Dialer{Opt: *opt} } -func dualStackDial( - ipv4DialFn func() (net.Conn, error), - ipv6DialFn func() (net.Conn, error), +func dualStackDialContext( + ctx context.Context, + ipv4DialFn func(ctx context.Context) (net.Conn, error), + ipv6DialFn func(ctx context.Context) (net.Conn, error), preferIPv4 bool) (net.Conn, error) { + fallbackTimer := time.NewTimer(fallbackTimeout) + defer fallbackTimer.Stop() + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() results := make(chan dialResult) returned := make(chan struct{}) defer close(returned) - racer := func(dial func() (net.Conn, error), isPrimary bool) { + racer := func(dial func(ctx context.Context) (net.Conn, error), isPrimary bool) { result := dialResult{isPrimary: isPrimary} defer func() { select { @@ -200,29 +213,33 @@ func dualStackDial( } } }() - result.Conn, result.error = dial() + result.Conn, result.error = dial(fallbackCtx) } go racer(ipv4DialFn, preferIPv4) go racer(ipv6DialFn, !preferIPv4) - var fallbackErr dialResult - var primaryErr dialResult - for res := range results { - if res.error == nil { - if res.isPrimary { - return res.Conn, nil + var fallback dialResult + var err error + for { + select { + case <-ctx.Done(): + if fallback.error == nil && fallback.Conn != nil { + return fallback.Conn, nil } - fallbackErr = res - } - if res.isPrimary { - primaryErr = res - } else { - fallbackErr = res + return nil, fmt.Errorf("dual stack connect failed: %w", err) + case <-fallbackTimer.C: + if fallback.error == nil && fallback.Conn != nil { + return fallback.Conn, nil + } + case res := <-results: + if res.error == nil { + if res.isPrimary { + return res.Conn, nil + } + fallback = res + } + err = res.error } } - if primaryErr.error != nil { - return nil, primaryErr.error - } - return nil, fallbackErr.error } func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {