diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 479def67..d70e9173 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -182,13 +182,14 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, go racer(ipv6s, preferIPVersion != 4) var fallback dialResult var errs []error - for i := 0; i < 2; i++ { + for i := 0; i < 2; { select { case <-fallbackTicker.C: if fallback.error == nil && fallback.Conn != nil { return fallback.Conn, nil } case res := <-results: + i++ if res.error == nil { if res.isPrimary { return res.Conn, nil @@ -217,7 +218,7 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, returned := make(chan struct{}) defer close(returned) racer := func(ctx context.Context, ip netip.Addr) { - result := dialResult{isPrimary: true} + result := dialResult{isPrimary: true, ip: ip} defer func() { select { case results <- result: @@ -227,7 +228,6 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, } } }() - result.ip = ip result.Conn, result.error = dialContext(ctx, network, ip, port, opt) } @@ -235,23 +235,18 @@ func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, go racer(ctx, ip) } var errs []error - for { - select { - case <-ctx.Done(): - if len(errs) > 0 { - return nil, errorsJoin(errs...) - } - if ctx.Err() == context.DeadlineExceeded { - return nil, os.ErrDeadlineExceeded - } - return nil, ctx.Err() - case res := <-results: - if res.error == nil { - return res.Conn, nil - } - errs = append(errs, res.error) + for i := 0; i < len(ips); i++ { + res := <-results + if res.error == nil { + return res.Conn, nil } + errs = append(errs, res.error) } + + if len(errs) > 0 { + return nil, errorsJoin(errs...) + } + return nil, os.ErrDeadlineExceeded } func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {