diff --git a/dns/doh.go b/dns/doh.go index 8c38e2f4..f34246d5 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -208,14 +208,13 @@ func (doh *dnsOverHTTPS) exchangeHTTPSClient( } doh.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)) - httpReq, err := http.NewRequest(method, doh.url.String(), nil) + httpReq, err := http.NewRequestWithContext(ctx, method, doh.url.String(), nil) if err != nil { return nil, fmt.Errorf("creating http request to %s: %w", doh.url, err) } httpReq.Header.Set("Accept", "application/dns-message") httpReq.Header.Set("User-Agent", "") - _ = httpReq.WithContext(ctx) httpResp, err := client.Do(httpReq) if err != nil { return nil, fmt.Errorf("requesting %s: %w", doh.url, err) diff --git a/dns/util.go b/dns/util.go index 5bf09b8f..90693eae 100644 --- a/dns/util.go +++ b/dns/util.go @@ -216,7 +216,7 @@ func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.M fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout) for _, client := range clients { r := client - fast.Go(func() (*D.Msg, error) { + fn := func() (*D.Msg, error) { m, err := r.ExchangeContext(ctx, m) if err != nil { return nil, err @@ -224,6 +224,27 @@ func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.M return nil, errors.New("server failure") } return m, nil + } + fast.Go(func() (*D.Msg, error) { + ch := make(chan result, 1) + go func() { + m, err := fn() + ch <- result{ + Msg: m, + Error: err, + } + }() + select { + case r := <-ch: + return r.Msg, r.Error + case <-ctx.Done(): + select { + case r := <-ch: + return r.Msg, r.Error + default: + return nil, ctx.Err() + } + } }) }