diff --git a/config/config.go b/config/config.go index 57bf55e8..f1a85969 100644 --- a/config/config.go +++ b/config/config.go @@ -863,7 +863,7 @@ func hostWithDefaultPort(host string, defPort string) (string, error) { return net.JoinHostPort(hostname, port), nil } -func parseNameServer(servers []string) ([]dns.NameServer, error) { +func parseNameServer(servers []string, preferH3 bool) ([]dns.NameServer, error) { var nameservers []dns.NameServer for idx, server := range servers { @@ -889,7 +889,15 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) { addr, err = hostWithDefaultPort(u.Host, "853") dnsNetType = "tcp-tls" // DNS over TLS case "https": - clearURL := url.URL{Scheme: "https", Host: u.Host, Path: u.Path} + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil && strings.Contains(err.Error(), "missing port in address") { + host = net.JoinHostPort(host, "443") + } else { + if err!=nil{ + return nil,err + } + } + clearURL := url.URL{Scheme: "https", Host: host, Path: u.Path} addr = clearURL.String() dnsNetType = "https" // DNS over HTTPS if len(u.Fragment) != 0 { @@ -928,17 +936,18 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) { ProxyAdapter: proxyAdapter, Interface: dialer.DefaultInterface, Params: params, + PreferH3: preferH3, }, ) } return nameservers, nil } -func parseNameServerPolicy(nsPolicy map[string]string) (map[string]dns.NameServer, error) { +func parseNameServerPolicy(nsPolicy map[string]string, preferH3 bool) (map[string]dns.NameServer, error) { policy := map[string]dns.NameServer{} for domain, server := range nsPolicy { - nameservers, err := parseNameServer([]string{server}) + nameservers, err := parseNameServer([]string{server}, preferH3) if err != nil { return nil, err } @@ -1018,26 +1027,26 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R }, } var err error - if dnsCfg.NameServer, err = parseNameServer(cfg.NameServer); err != nil { + if dnsCfg.NameServer, err = parseNameServer(cfg.NameServer, cfg.PreferH3); err != nil { return nil, err } - if dnsCfg.Fallback, err = parseNameServer(cfg.Fallback); err != nil { + if dnsCfg.Fallback, err = parseNameServer(cfg.Fallback, cfg.PreferH3); err != nil { return nil, err } - if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy); err != nil { + if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, cfg.PreferH3); err != nil { return nil, err } - if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver); err != nil { + if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver, cfg.PreferH3); err != nil { return nil, err } if len(cfg.DefaultNameserver) == 0 { return nil, errors.New("default nameserver should have at least one nameserver") } - if dnsCfg.DefaultNameserver, err = parseNameServer(cfg.DefaultNameserver); err != nil { + if dnsCfg.DefaultNameserver, err = parseNameServer(cfg.DefaultNameserver, cfg.PreferH3); err != nil { return nil, err } // check default nameserver is pure ip addr diff --git a/constant/dns.go b/constant/dns.go index be8b4a17..da68753c 100644 --- a/constant/dns.go +++ b/constant/dns.go @@ -114,3 +114,14 @@ func NewDNSPrefer(prefer string) DNSPrefer { return DualStack } } + +type HTTPVersion string + +const ( + // HTTPVersion11 is HTTP/1.1. + HTTPVersion11 HTTPVersion = "http/1.1" + // HTTPVersion2 is HTTP/2. + HTTPVersion2 HTTPVersion = "h2" + // HTTPVersion3 is HTTP/3. + HTTPVersion3 HTTPVersion = "h3" +) \ No newline at end of file diff --git a/dns/doh.go b/dns/doh.go index 8403f7d1..d5a8b06e 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -1,164 +1,730 @@ package dns import ( - "bytes" "context" "crypto/tls" + "encoding/base64" + "errors" "fmt" - "github.com/Dreamacro/clash/component/dialer" - "github.com/Dreamacro/clash/component/resolver" - tlsC "github.com/Dreamacro/clash/component/tls" - "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/http3" - D "github.com/miekg/dns" "io" "net" "net/http" + "net/url" + "runtime" "strconv" + "sync" + "time" + + "github.com/Dreamacro/clash/component/dialer" + tlsC "github.com/Dreamacro/clash/component/tls" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" + "github.com/miekg/dns" + D "github.com/miekg/dns" + "golang.org/x/net/http2" ) +// Values to configure HTTP and HTTP/2 transport. const ( - // dotMimeType is the DoH mimetype that should be used. - dotMimeType = "application/dns-message" + // transportDefaultReadIdleTimeout is the default timeout for pinging + // idle connections in HTTP/2 transport. + transportDefaultReadIdleTimeout = 30 * time.Second + + // transportDefaultIdleConnTimeout is the default timeout for idle + // connections in HTTP transport. + transportDefaultIdleConnTimeout = 5 * time.Minute + + // dohMaxConnsPerHost controls the maximum number of connections for + // each host. + dohMaxConnsPerHost = 1 + dialTimeout = 10 * time.Second + + // dohMaxIdleConns controls the maximum number of connections being idle + // at the same time. + dohMaxIdleConns = 1 + maxElapsedTime = time.Second * 30 ) -type dohClient struct { - url string - transport http.RoundTripper +var DefaultHTTPVersions = []C.HTTPVersion{C.HTTPVersion11, C.HTTPVersion2} + +// dnsOverHTTPS is a struct that implements the Upstream interface for the +// DNS-over-HTTPS protocol. +type dnsOverHTTPS struct { + // The Client's Transport typically has internal state (cached TCP + // connections), so Clients should be reused instead of created as + // needed. Clients are safe for concurrent use by multiple goroutines. + client *http.Client + clientMu sync.Mutex + + // quicConfig is the QUIC configuration that is used if HTTP/3 is enabled + // for this upstream. + quicConfig *quic.Config + quicConfigGuard sync.Mutex + url *url.URL + r *Resolver + httpVersions []C.HTTPVersion + proxyAdapter string } -func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) { - return dc.ExchangeContext(context.Background(), m) +// type check +var _ dnsClient = (*dnsOverHTTPS)(nil) + +// newDoH returns the DNS-over-HTTPS Upstream. +func newDoHClient(urlString string, r *Resolver, preferH3 bool, params map[string]string, proxyAdapter string) dnsClient { + u, _ := url.Parse(urlString) + httpVersions := DefaultHTTPVersions + if preferH3 { + httpVersions = append(httpVersions, C.HTTPVersion3) + } + + if params["h3"] == "true" { + httpVersions = []C.HTTPVersion{C.HTTPVersion3} + } + + doh := &dnsOverHTTPS{ + url: u, + r: r, + quicConfig: &quic.Config{ + KeepAlivePeriod: QUICKeepAlivePeriod, + TokenStore: newQUICTokenStore(), + }, + httpVersions: httpVersions, + } + + runtime.SetFinalizer(doh, (*dnsOverHTTPS).Close) + + return doh } -func (dc *dohClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { - // https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 - // In order to maximize cache friendliness, SHOULD use a DNS ID of 0 in every DNS request. - newM := *m - newM.Id = 0 - req, err := dc.newRequest(&newM) +// Address implements the Upstream interface for *dnsOverHTTPS. +func (p *dnsOverHTTPS) Address() string { return p.url.String() } +func (p *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { + // Quote from https://www.rfc-editor.org/rfc/rfc8484.html: + // In order to maximize HTTP cache friendliness, DoH clients using media + // formats that include the ID field from the DNS message header, such + // as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS + // request. + id := m.Id + m.Id = 0 + defer func() { + // Restore the original ID to not break compatibility with proxies. + m.Id = id + if msg != nil { + msg.Id = id + } + }() + + // Check if there was already an active client before sending the request. + // We'll only attempt to re-connect if there was one. + client, isCached, err := p.getClient() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to init http client: %w", err) } - req = req.WithContext(ctx) - msg, err = dc.doRequest(req) - if err == nil { - msg.Id = m.Id - } - return -} + // Make the first attempt to send the DNS query. + msg, err = p.exchangeHTTPS(ctx, client, m) + + // Make up to 2 attempts to re-create the HTTP client and send the request + // again. There are several cases (mostly, with QUIC) where this workaround + // is necessary to make HTTP client usable. We need to make 2 attempts in + // the case when the connection was closed (due to inactivity for example) + // AND the server refuses to open a 0-RTT connection. + for i := 0; isCached && p.shouldRetry(err) && i < 2; i++ { + client, err = p.resetClient(err) + if err != nil { + return nil, fmt.Errorf("failed to reset http client: %w", err) + } + + msg, err = p.exchangeHTTPS(ctx, client, m) + } -// newRequest returns a new DoH request given a dns.Msg. -func (dc *dohClient) newRequest(m *D.Msg) (*http.Request, error) { - buf, err := m.Pack() if err != nil { - return nil, err + // If the request failed anyway, make sure we don't use this client. + _, resErr := p.resetClient(err) + + return nil, fmt.Errorf("err:%v,resErr:%v", err, resErr) } - req, err := http.NewRequest(http.MethodPost, dc.url, bytes.NewReader(buf)) - if err != nil { - return req, err - } - - req.Header.Set("content-type", dotMimeType) - req.Header.Set("accept", dotMimeType) - return req, nil -} - -func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { - client := &http.Client{Transport: dc.transport} - resp, err := client.Do(req) - if err != nil { - return nil, err - } - - defer resp.Body.Close() - - buf, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - msg = &D.Msg{} - err = msg.Unpack(buf) return msg, err } -func newDoHClient(url string, r *Resolver, params map[string]string, proxyAdapter string) *dohClient { - useH3 := params["h3"] == "true" - TLCConfig := tlsC.GetDefaultTLSConfig() - var transport http.RoundTripper - if useH3 { - transport = &http3.RoundTripper{ - Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } +// Exchange implements the Upstream interface for *dnsOverHTTPS. +func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { + return p.ExchangeContext(context.Background(), m) +} - ip, err := resolver.ResolveIPWithResolver(host, r) - if err != nil { - return nil, err - } +// Close implements the Upstream interface for *dnsOverHTTPS. +func (p *dnsOverHTTPS) Close() (err error) { + p.clientMu.Lock() + defer p.clientMu.Unlock() - portInt, err := strconv.Atoi(port) - if err != nil { - return nil, err - } + runtime.SetFinalizer(p, nil) - udpAddr := net.UDPAddr{ - IP: net.ParseIP(ip.String()), - Port: portInt, - } + if p.client == nil { + return nil + } - var conn net.PacketConn - if proxyAdapter == "" { - conn, err = dialer.ListenPacket(ctx, "udp", "") - if err != nil { - return nil, err - } - } else { - if wrapConn, err := dialContextExtra(ctx, proxyAdapter, "udp", ip, port); err == nil { - if pc, ok := wrapConn.(*wrapPacketConn); ok { - conn = pc - } else { - return nil, fmt.Errorf("conn isn't wrapPacketConn") - } - } else { - return nil, err - } - } + return p.closeClient(p.client) +} - return quic.DialEarlyContext(ctx, conn, &udpAddr, host, tlsCfg, cfg) - }, - TLSClientConfig: TLCConfig, +// closeClient cleans up resources used by client if necessary. Note, that at +// this point it should only be done for HTTP/3 as it may leak due to keep-alive +// connections. +func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) { + if isHTTP3(client) { + return client.Transport.(io.Closer).Close() + } + + return nil +} + +// exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient. +func (p *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *dns.Msg) (resp *dns.Msg, err error) { + resp, err = p.exchangeHTTPSClient(ctx, client, req) + + return resp, err +} + +// exchangeHTTPSClient sends the DNS query to a DoH resolver using the specified +// http.Client instance. +func (p *dnsOverHTTPS) exchangeHTTPSClient( + ctx context.Context, + client *http.Client, + req *dns.Msg, +) (resp *dns.Msg, err error) { + buf, err := req.Pack() + if err != nil { + return nil, fmt.Errorf("packing message: %w", err) + } + + // It appears, that GET requests are more memory-efficient with Golang + // implementation of HTTP/2. + method := http.MethodGet + if isHTTP3(client) { + // If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT. + method = http3.MethodGet0RTT + } + + p.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)) + httpReq, err := http.NewRequest(method, p.url.String(), nil) + if err != nil { + return nil, fmt.Errorf("creating http request to %s: %w", p.url, err) + } + + httpReq.Header.Set("Accept", "application/dns-message") + httpReq.Header.Set("User-Agent", "") + _ = httpReq.WithContext(ctx) + httpResp, err := client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("requesting %s: %w", p.url, err) + } + defer httpResp.Body.Close() + + body, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, fmt.Errorf("reading %s: %w", p.url, err) + } + + if httpResp.StatusCode != http.StatusOK { + return nil, + fmt.Errorf( + "expected status %d, got %d from %s", + http.StatusOK, + httpResp.StatusCode, + p.url, + ) + } + + resp = &dns.Msg{} + err = resp.Unpack(body) + if err != nil { + return nil, fmt.Errorf( + "unpacking response from %s: body is %s: %w", + p.url, + body, + err, + ) + } + + if resp.Id != req.Id { + err = dns.ErrId + } + + return resp, err +} + +// shouldRetry checks what error we have received and returns true if we should +// re-create the HTTP client and retry the request. +func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) { + if err == nil { + return false + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + // If this is a timeout error, trying to forcibly re-create the HTTP + // client instance. This is an attempt to fix an issue with DoH client + // stalling after a network change. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/3217. + return true + } + + if isQUICRetryError(err) { + return true + } + + return false +} + +// resetClient triggers re-creation of the *http.Client that is used by this +// upstream. This method accepts the error that caused resetting client as +// depending on the error we may also reset the QUIC config. +func (p *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) { + p.clientMu.Lock() + defer p.clientMu.Unlock() + + if errors.Is(resetErr, quic.Err0RTTRejected) { + // Reset the TokenStore only if 0-RTT was rejected. + p.resetQUICConfig() + } + + oldClient := p.client + if oldClient != nil { + closeErr := p.closeClient(oldClient) + if closeErr != nil { + log.Warnln("warning: failed to close the old http client: %v", closeErr) + } + } + + log.Debugln("re-creating the http client due to %v", resetErr) + p.client, err = p.createClient() + + return p.client, err +} + +// getQUICConfig returns the QUIC config in a thread-safe manner. Note, that +// this method returns a pointer, it is forbidden to change its properties. +func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) { + p.quicConfigGuard.Lock() + defer p.quicConfigGuard.Unlock() + + return p.quicConfig +} + +// resetQUICConfig Re-create the token store to make sure we're not trying to +// use invalid for 0-RTT. +func (p *dnsOverHTTPS) resetQUICConfig() { + p.quicConfigGuard.Lock() + defer p.quicConfigGuard.Unlock() + + p.quicConfig = p.quicConfig.Clone() + p.quicConfig.TokenStore = newQUICTokenStore() +} + +// getClient gets or lazily initializes an HTTP client (and transport) that will +// be used for this DoH resolver. +func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) { + startTime := time.Now() + + p.clientMu.Lock() + defer p.clientMu.Unlock() + if p.client != nil { + return p.client, true, nil + } + + // Timeout can be exceeded while waiting for the lock. This happens quite + // often on mobile devices. + elapsed := time.Since(startTime) + if elapsed > maxElapsedTime { + return nil, false, fmt.Errorf("timeout exceeded: %s", elapsed) + } + + log.Debugln("creating a new http client") + p.client, err = p.createClient() + + return p.client, false, err +} + +// createClient creates a new *http.Client instance. The HTTP protocol version +// will depend on whether HTTP3 is allowed and provided by this upstream. Note, +// that we'll attempt to establish a QUIC connection when creating the client in +// order to check whether HTTP3 is supported. +func (p *dnsOverHTTPS) createClient() (*http.Client, error) { + transport, err := p.createTransport() + if err != nil { + return nil, fmt.Errorf("initializing http transport: %w", err) + } + + client := &http.Client{ + Transport: transport, + Timeout: DefaultTimeout, + Jar: nil, + } + + p.client = client + + return p.client, nil +} + +// createTransport initializes an HTTP transport that will be used specifically +// for this DoH resolver. This HTTP transport ensures that the HTTP requests +// will be sent exactly to the IP address got from the bootstrap resolver. Note, +// that this function will first attempt to establish a QUIC connection (if +// HTTP3 is enabled in the upstream options). If this attempt is successful, +// it returns an HTTP3 transport, otherwise it returns the H1/H2 transport. +func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) { + tlsConfig := tlsC.GetGlobalFingerprintTLCConfig( + &tls.Config{ + InsecureSkipVerify: false, + MinVersion: tls.VersionTLS12, + SessionTicketsDisabled: false, + }) + var nextProtos []string + for _, v := range p.httpVersions { + nextProtos = append(nextProtos, string(v)) + } + tlsConfig.NextProtos = nextProtos + dialContext := getDialHandler(p.r, p.proxyAdapter) + // First, we attempt to create an HTTP3 transport. If the probe QUIC + // connection is established successfully, we'll be using HTTP3 for this + // upstream. + transportH3, err := p.createTransportH3(tlsConfig, dialContext) + if err == nil { + log.Debugln("using HTTP/3 for this upstream: QUIC was faster") + return transportH3, nil + } + + log.Debugln("using HTTP/2 for this upstream: %v", err) + + if !p.supportsHTTP() { + return nil, errors.New("HTTP1/1 and HTTP2 are not supported by this upstream") + } + + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + DisableCompression: true, + DialContext: dialContext, + IdleConnTimeout: transportDefaultIdleConnTimeout, + MaxConnsPerHost: dohMaxConnsPerHost, + MaxIdleConns: dohMaxIdleConns, + // Since we have a custom DialContext, we need to use this field to + // make golang http.Client attempt to use HTTP/2. Otherwise, it would + // only be used when negotiated on the TLS level. + ForceAttemptHTTP2: true, + } + + // Explicitly configure transport to use HTTP/2. + // + // See https://github.com/AdguardTeam/dnsproxy/issues/11. + var transportH2 *http2.Transport + transportH2, err = http2.ConfigureTransports(transport) + if err != nil { + return nil, err + } + + // Enable HTTP/2 pings on idle connections. + transportH2.ReadIdleTimeout = transportDefaultReadIdleTimeout + + return transport, nil +} + +// http3Transport is a wrapper over *http3.RoundTripper that tries to optimize +// its behavior. The main thing that it does is trying to force use a single +// connection to a host instead of creating a new one all the time. It also +// helps mitigate race issues with quic-go. +type http3Transport struct { + baseTransport *http3.RoundTripper + + closed bool + mu sync.RWMutex +} + +// type check +var _ http.RoundTripper = (*http3Transport)(nil) + +// RoundTrip implements the http.RoundTripper interface for *http3Transport. +func (h *http3Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + h.mu.RLock() + defer h.mu.RUnlock() + + if h.closed { + return nil, net.ErrClosed + } + + // Try to use cached connection to the target host if it's available. + resp, err = h.baseTransport.RoundTripOpt(req, http3.RoundTripOpt{OnlyCachedConn: true}) + + if errors.Is(err, http3.ErrNoCachedConn) { + // If there are no cached connection, trigger creating a new one. + resp, err = h.baseTransport.RoundTrip(req) + } + + return resp, err +} + +// type check +var _ io.Closer = (*http3Transport)(nil) + +// Close implements the io.Closer interface for *http3Transport. +func (h *http3Transport) Close() (err error) { + h.mu.Lock() + defer h.mu.Unlock() + + h.closed = true + + return h.baseTransport.Close() +} + +// createTransportH3 tries to create an HTTP/3 transport for this upstream. +// We should be able to fall back to H1/H2 in case if HTTP/3 is unavailable or +// if it is too slow. In order to do that, this method will run two probes +// in parallel (one for TLS, the other one for QUIC) and if QUIC is faster it +// will create the *http3.RoundTripper instance. +func (doh *dnsOverHTTPS) createTransportH3( + tlsConfig *tls.Config, + dialContext dialHandler, +) (roundTripper http.RoundTripper, err error) { + if !doh.supportsH3() { + return nil, errors.New("HTTP3 support is not enabled") + } + + addr, err := doh.probeH3(tlsConfig, dialContext) + if err != nil { + return nil, err + } + + rt := &http3.RoundTripper{ + Dial: func( + ctx context.Context, + + // Ignore the address and always connect to the one that we got + // from the bootstrapper. + _ string, + tlsCfg *tls.Config, + cfg *quic.Config, + ) (c quic.EarlyConnection, err error) { + return doh.dialQuic(ctx, addr, tlsCfg, cfg) + }, + DisableCompression: true, + TLSClientConfig: tlsConfig, + QuicConfig: doh.getQUICConfig(), + } + + return &http3Transport{baseTransport: rt}, nil +} + +func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + ip, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + portInt, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + udpAddr := net.UDPAddr{ + IP: net.ParseIP(ip), + Port: portInt, + } + var conn net.PacketConn + if doh.proxyAdapter == "" { + conn, err = dialer.ListenPacket(ctx, "udp", "") + if err != nil { + return nil, err } } else { - transport = &http.Transport{ - ForceAttemptHTTP2: true, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } + if wrapConn, err := dialContextExtra(ctx, doh.proxyAdapter, "udp", udpAddr.AddrPort().Addr(), port); err == nil { + if pc, ok := wrapConn.(*wrapPacketConn); ok { + conn = pc + } else { + return nil, fmt.Errorf("conn isn't wrapPacketConn") + } + } else { + return nil, err + } + } + return quic.DialEarlyContext(ctx, conn, &udpAddr, doh.url.Host, tlsCfg, cfg) +} - ip, err := resolver.ResolveIPWithResolver(host, r) - if err != nil { - return nil, err - } +// probeH3 runs a test to check whether QUIC is faster than TLS for this +// upstream. If the test is successful it will return the address that we +// should use to establish the QUIC connections. +func (p *dnsOverHTTPS) probeH3( + tlsConfig *tls.Config, + dialContext dialHandler, +) (addr string, err error) { + // We're using bootstrapped address instead of what's passed to the function + // it does not create an actual connection, but it helps us determine + // what IP is actually reachable (when there are v4/v6 addresses). + rawConn, err := dialContext(context.Background(), "udp", p.url.Host) + if err != nil { + return "", fmt.Errorf("failed to dial: %w", err) + } + // It's never actually used. + _ = rawConn.Close() - if proxyAdapter == "" { - return dialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), port)) - } else { - return dialContextExtra(ctx, proxyAdapter, "tcp", ip, port) - } - }, - TLSClientConfig: TLCConfig, + udpConn, ok := rawConn.(*net.UDPConn) + if !ok { + return "", fmt.Errorf("not a UDP connection to %s", p.Address()) + } + + addr = udpConn.RemoteAddr().String() + + // Avoid spending time on probing if this upstream only supports HTTP/3. + if p.supportsH3() && !p.supportsHTTP() { + return addr, nil + } + + // Use a new *tls.Config with empty session cache for probe connections. + // Surprisingly, this is really important since otherwise it invalidates + // the existing cache. + // TODO(ameshkov): figure out why the sessions cache invalidates here. + probeTLSCfg := tlsConfig.Clone() + probeTLSCfg.ClientSessionCache = nil + + // Do not expose probe connections to the callbacks that are passed to + // the bootstrap options to avoid side-effects. + // TODO(ameshkov): consider exposing, somehow mark that this is a probe. + probeTLSCfg.VerifyPeerCertificate = nil + probeTLSCfg.VerifyConnection = nil + + // Run probeQUIC and probeTLS in parallel and see which one is faster. + chQuic := make(chan error, 1) + chTLS := make(chan error, 1) + go p.probeQUIC(addr, probeTLSCfg, chQuic) + go p.probeTLS(dialContext, probeTLSCfg, chTLS) + + select { + case quicErr := <-chQuic: + if quicErr != nil { + // QUIC failed, return error since HTTP3 was not preferred. + return "", quicErr + } + + // Return immediately, QUIC was faster. + return addr, quicErr + case tlsErr := <-chTLS: + if tlsErr != nil { + // Return immediately, TLS failed. + log.Debugln("probing TLS: %v", tlsErr) + return addr, nil + } + + return "", errors.New("TLS was faster than QUIC, prefer it") + } +} + +// probeQUIC attempts to establish a QUIC connection to the specified address. +// We run probeQUIC and probeTLS in parallel and see which one is faster. +func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) { + startTime := time.Now() + + timeout := DefaultTimeout + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) + defer cancel() + + conn, err := p.dialQuic(ctx, addr, tlsConfig, p.getQUICConfig()) + if err != nil { + ch <- fmt.Errorf("opening QUIC connection to %s: %w", p.Address(), err) + return + } + + // Ignore the error since there's no way we can use it for anything useful. + _ = conn.CloseWithError(QUICCodeNoError, "") + + ch <- nil + + elapsed := time.Now().Sub(startTime) + log.Debugln("elapsed on establishing a QUIC connection: %s", elapsed) +} + +// probeTLS attempts to establish a TLS connection to the specified address. We +// run probeQUIC and probeTLS in parallel and see which one is faster. +func (p *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) { + startTime := time.Now() + + conn, err := p.tlsDial(dialContext, "tcp", tlsConfig) + if err != nil { + ch <- fmt.Errorf("opening TLS connection: %w", err) + return + } + + // Ignore the error since there's no way we can use it for anything useful. + _ = conn.Close() + + ch <- nil + + elapsed := time.Now().Sub(startTime) + log.Debugln("elapsed on establishing a TLS connection: %s", elapsed) +} + +// supportsH3 returns true if HTTP/3 is supported by this upstream. +func (p *dnsOverHTTPS) supportsH3() (ok bool) { + for _, v := range p.supportedHTTPVersions() { + if v == C.HTTPVersion3 { + return true } } - return &dohClient{ - url: url, - transport: transport, - } + return false +} + +// supportsHTTP returns true if HTTP/1.1 or HTTP2 is supported by this upstream. +func (p *dnsOverHTTPS) supportsHTTP() (ok bool) { + for _, v := range p.supportedHTTPVersions() { + if v == C.HTTPVersion11 || v == C.HTTPVersion2 { + return true + } + } + + return false +} + +// supportedHTTPVersions returns the list of supported HTTP versions. +func (p *dnsOverHTTPS) supportedHTTPVersions() (v []C.HTTPVersion) { + v = p.httpVersions + if v == nil { + v = DefaultHTTPVersions + } + + return v +} + +// isHTTP3 checks if the *http.Client is an HTTP/3 client. +func isHTTP3(client *http.Client) (ok bool) { + _, ok = client.Transport.(*http3Transport) + + return ok +} + +// tlsDial is basically the same as tls.DialWithDialer, but we will call our own +// dialContext function to get connection. +func (doh *dnsOverHTTPS) tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) { + // We're using bootstrapped address instead of what's passed + // to the function. + rawConn, err := dialContext(context.Background(), network, doh.url.Host) + if err != nil { + return nil, err + } + + // We want the timeout to cover the whole process: TCP connection and + // TLS handshake dialTimeout will be used as connection deadLine. + conn := tls.Client(rawConn, config) + + err = conn.SetDeadline(time.Now().Add(dialTimeout)) + if err != nil { + // Must not happen in normal circumstances. + panic(fmt.Errorf("cannot set deadline: %w", err)) + } + + err = conn.Handshake() + if err != nil { + defer conn.Close() + return nil, err + } + + return conn, nil } diff --git a/dns/doq.go b/dns/doq.go index 7807de1c..734d26d0 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -1,134 +1,307 @@ package dns import ( - "bytes" "context" "crypto/tls" + "encoding/binary" + "errors" "fmt" - "github.com/Dreamacro/clash/component/dialer" - "github.com/Dreamacro/clash/component/resolver" - tlsC "github.com/Dreamacro/clash/component/tls" - "github.com/lucas-clemente/quic-go" "net" + "net/netip" + "runtime" "strconv" "sync" "time" + "github.com/Dreamacro/clash/component/dialer" + tlsC "github.com/Dreamacro/clash/component/tls" + "github.com/lucas-clemente/quic-go" + "github.com/Dreamacro/clash/log" D "github.com/miekg/dns" ) const NextProtoDQ = "doq" +const ( + // QUICCodeNoError is used when the connection or stream needs to be closed, + // but there is no error to signal. + QUICCodeNoError = quic.ApplicationErrorCode(0) + // QUICCodeInternalError signals that the DoQ implementation encountered + // an internal error and is incapable of pursuing the transaction or the + // connection. + QUICCodeInternalError = quic.ApplicationErrorCode(1) + // QUICKeepAlivePeriod is the value that we pass to *quic.Config and that + // controls the period with with keep-alive frames are being sent to the + // connection. We set it to 20s as it would be in the quic-go@v0.27.1 with + // KeepAlive field set to true This value is specified in + // https://pkg.go.dev/github.com/lucas-clemente/quic-go/internal/protocol#MaxKeepAliveInterval. + // + // TODO(ameshkov): Consider making it configurable. + QUICKeepAlivePeriod = time.Second * 20 + DefaultTimeout = time.Second * 5 +) -var bytesPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }} +type dialHandler func(ctx context.Context, network, addr string) (net.Conn, error) + +// dnsOverQUIC is a struct that implements the Upstream interface for the +// DNS-over-QUIC protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html). +type dnsOverQUIC struct { + // quicConfig is the QUIC configuration that is used for establishing + // connections to the upstream. This configuration includes the TokenStore + // that needs to be stored for the lifetime of dnsOverQUIC since we can + // re-create the connection. + quicConfig *quic.Config + quicConfigGuard sync.Mutex + + // conn is the current active QUIC connection. It can be closed and + // re-opened when needed. + conn quic.Connection + connMu sync.RWMutex + + // bytesPool is a *sync.Pool we use to store byte buffers in. These byte + // buffers are used to read responses from the upstream. + bytesPool *sync.Pool + bytesPoolGuard sync.Mutex -type quicClient struct { addr string - r *Resolver - connection quic.Connection proxyAdapter string - udp net.PacketConn - sync.RWMutex // protects connection and bytesPool + r *Resolver } -func newDOQ(r *Resolver, addr, proxyAdapter string) *quicClient { - return &quicClient{ +// type check +var _ dnsClient = (*dnsOverQUIC)(nil) + +// newDoQ returns the DNS-over-QUIC Upstream. +func newDoQ(resolver *Resolver, addr string, adapter string) (dnsClient, error) { + doq := &dnsOverQUIC{ addr: addr, - r: r, - proxyAdapter: proxyAdapter, + proxyAdapter: adapter, + r: resolver, + quicConfig: &quic.Config{ + KeepAlivePeriod: QUICKeepAlivePeriod, + TokenStore: newQUICTokenStore(), + }, } + + runtime.SetFinalizer(doq, (*dnsOverQUIC).Close) + return doq, nil } -func (dc *quicClient) Exchange(m *D.Msg) (msg *D.Msg, err error) { - return dc.ExchangeContext(context.Background(), m) -} +// Address implements the Upstream interface for *dnsOverQUIC. +func (p *dnsOverQUIC) Address() string { return p.addr } + +func (p *dnsOverQUIC) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { + // When sending queries over a QUIC connection, the DNS Message ID MUST be + // set to zero. + id := m.Id + m.Id = 0 + defer func() { + // Restore the original ID to not break compatibility with proxies. + m.Id = id + if msg != nil { + msg.Id = id + } + }() + + // Check if there was already an active conn before sending the request. + // We'll only attempt to re-connect if there was one. + hasConnection := p.hasConnection() + + // Make the first attempt to send the DNS query. + msg, err = p.exchangeQUIC(ctx, m) + + // Make up to 2 attempts to re-open the QUIC connection and send the request + // again. There are several cases where this workaround is necessary to + // make DoQ usable. We need to make 2 attempts in the case when the + // connection was closed (due to inactivity for example) AND the server + // refuses to open a 0-RTT connection. + for i := 0; hasConnection && p.shouldRetry(err) && i < 2; i++ { + log.Debugln("re-creating the QUIC connection and retrying due to %v", err) + + // Close the active connection to make sure we'll try to re-connect. + p.closeConnWithError(err) + + // Retry sending the request. + msg, err = p.exchangeQUIC(ctx, m) + } -func (dc *quicClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { - stream, err := dc.openStream(ctx) if err != nil { - return nil, fmt.Errorf("failed to open new stream to %s", dc.addr) + // If we're unable to exchange messages, make sure the connection is + // closed and signal about an internal error. + p.closeConnWithError(err) } - buf, err := m.Pack() + return msg, err +} + +// Exchange implements the Upstream interface for *dnsOverQUIC. +func (p *dnsOverQUIC) Exchange(m *D.Msg) (msg *D.Msg, err error) { + return p.ExchangeContext(context.Background(), m) +} + +// Close implements the Upstream interface for *dnsOverQUIC. +func (p *dnsOverQUIC) Close() (err error) { + p.connMu.Lock() + defer p.connMu.Unlock() + + runtime.SetFinalizer(p, nil) + + if p.conn != nil { + err = p.conn.CloseWithError(QUICCodeNoError, "") + } + + return err +} + +// exchangeQUIC attempts to open a QUIC connection, send the DNS message +// through it and return the response it got from the server. +func (p *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) { + var conn quic.Connection + conn, err = p.getConnection(true) if err != nil { return nil, err } - _, err = stream.Write(buf) + var buf []byte + buf, err = msg.Pack() + if err != nil { + return nil, fmt.Errorf("failed to pack DNS message for DoQ: %w", err) + } + + var stream quic.Stream + stream, err = p.openStream(ctx, conn) if err != nil { return nil, err } + _, err = stream.Write(AddPrefix(buf)) + if err != nil { + return nil, fmt.Errorf("failed to write to a QUIC stream: %w", err) + } + // The client MUST send the DNS query over the selected stream, and MUST // indicate through the STREAM FIN mechanism that no further data will - // be sent on that stream. - // stream.Close() -- closes the write-direction of the stream. + // be sent on that stream. Note, that stream.Close() closes the + // write-direction of the stream, but does not prevent reading from it. _ = stream.Close() - respBuf := bytesPool.Get().(*bytes.Buffer) - defer bytesPool.Put(respBuf) - defer respBuf.Reset() - - n, err := respBuf.ReadFrom(stream) - if err != nil && n == 0 { - return nil, err - } - - reply := new(D.Msg) - err = reply.Unpack(respBuf.Bytes()) - if err != nil { - return nil, err - } - - return reply, nil + return p.readMsg(stream) } -func isActive(s quic.Connection) bool { - select { - case <-s.Context().Done(): - return false - default: - return true - } +// AddPrefix adds a 2-byte prefix with the DNS message length. +func AddPrefix(b []byte) (m []byte) { + m = make([]byte, 2+len(b)) + binary.BigEndian.PutUint16(m, uint16(len(b))) + copy(m[2:], b) + + return m } -// getConnection - opens or returns an existing quic.Connection -// useCached - if true and cached connection exists, return it right away -// otherwise - forcibly creates a new connection -func (dc *quicClient) getConnection(ctx context.Context) (quic.Connection, error) { - var connection quic.Connection - dc.RLock() - connection = dc.connection +// shouldRetry checks what error we received and decides whether it is required +// to re-open the connection and retry sending the request. +func (p *dnsOverQUIC) shouldRetry(err error) (ok bool) { + return isQUICRetryError(err) +} - if connection != nil && isActive(connection) { - dc.RUnlock() - return connection, nil - } +// getBytesPool returns (creates if needed) a pool we store byte buffers in. +func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) { + p.bytesPoolGuard.Lock() + defer p.bytesPoolGuard.Unlock() - dc.RUnlock() + if p.bytesPool == nil { + p.bytesPool = &sync.Pool{ + New: func() interface{} { + b := make([]byte, MaxMsgSize) - dc.Lock() - defer dc.Unlock() - connection = dc.connection - if connection != nil { - if isActive(connection) { - return connection, nil - } else { - _ = connection.CloseWithError(quic.ApplicationErrorCode(0), "") + return &b + }, } } - var err error - connection, err = dc.openConnection(ctx) - dc.connection = connection - return connection, err + return p.bytesPool } -func (dc *quicClient) openConnection(ctx context.Context) (quic.Connection, error) { - if dc.udp != nil { - _ = dc.udp.Close() +// getConnection opens or returns an existing quic.Connection. useCached +// argument controls whether we should try to use the existing cached +// connection. If it is false, we will forcibly create a new connection and +// close the existing one if needed. +func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) { + var conn quic.Connection + p.connMu.RLock() + conn = p.conn + if conn != nil && useCached { + p.connMu.RUnlock() + + return conn, nil + } + if conn != nil { + // we're recreating the connection, let's create a new one. + _ = conn.CloseWithError(QUICCodeNoError, "") + } + p.connMu.RUnlock() + + p.connMu.Lock() + defer p.connMu.Unlock() + + var err error + conn, err = p.openConnection() + if err != nil { + return nil, err + } + p.conn = conn + + return conn, nil +} + +// hasConnection returns true if there's an active QUIC connection. +func (p *dnsOverQUIC) hasConnection() (ok bool) { + p.connMu.Lock() + defer p.connMu.Unlock() + + return p.conn != nil +} + +// getQUICConfig returns the QUIC config in a thread-safe manner. Note, that +// this method returns a pointer, it is forbidden to change its properties. +func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) { + p.quicConfigGuard.Lock() + defer p.quicConfigGuard.Unlock() + + return p.quicConfig +} + +// resetQUICConfig re-creates the tokens store as we may need to use a new one +// if we failed to connect. +func (p *dnsOverQUIC) resetQUICConfig() { + p.quicConfigGuard.Lock() + defer p.quicConfigGuard.Unlock() + + p.quicConfig = p.quicConfig.Clone() + p.quicConfig.TokenStore = newQUICTokenStore() +} + +// openStream opens a new QUIC stream for the specified connection. +func (p *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (quic.Stream, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + stream, err := conn.OpenStreamSync(ctx) + if err == nil { + return stream, nil } + // We can get here if the old QUIC connection is not valid anymore. We + // should try to re-create the connection again in this case. + newConn, err := p.getConnection(false) + if err != nil { + return nil, err + } + // Open a new stream. + return newConn.OpenStreamSync(ctx) +} + +// openConnection opens a new QUIC connection. +func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) { tlsConfig := tlsC.GetGlobalFingerprintTLCConfig( &tls.Config{ InsecureSkipVerify: false, @@ -137,42 +310,45 @@ func (dc *quicClient) openConnection(ctx context.Context) (quic.Connection, erro }, SessionTicketsDisabled: false, }) - - quicConfig := &quic.Config{ - ConnectionIDLength: 12, - HandshakeIdleTimeout: time.Second * 8, - MaxIncomingStreams: 4, - KeepAlivePeriod: 10 * time.Second, - MaxIdleTimeout: time.Second * 120, - } - - log.Debugln("opening new connection to %s", dc.addr) - var ( - udp net.PacketConn - err error - ) - - host, port, err := net.SplitHostPort(dc.addr) - + // we're using bootstrapped address instead of what's passed to the function + // it does not create an actual connection, but it helps us determine + // what IP is actually reachable (when there're v4/v6 addresses). + ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout) + rawConn, err := getDialHandler(doq.r, doq.proxyAdapter)(ctx, "udp", doq.addr) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to open a QUIC connection: %w", err) + } + // It's never actually used + _ = rawConn.Close() + cancel() + + udpConn, ok := rawConn.(*net.UDPConn) + if !ok { + return nil, fmt.Errorf("failed to open connection to %s", doq.addr) } - ip, err := resolver.ResolveIPv4WithResolver(host, dc.r) + addr := udpConn.RemoteAddr().String() + + ip, port, err := net.SplitHostPort(addr) if err != nil { return nil, err } p, err := strconv.Atoi(port) - udpAddr := net.UDPAddr{IP: ip.AsSlice(), Port: p} - - if dc.proxyAdapter == "" { + udpAddr := net.UDPAddr{IP: net.ParseIP(ip), Port: p} + var udp net.PacketConn + if doq.proxyAdapter == "" { udp, err = dialer.ListenPacket(ctx, "udp", "") if err != nil { return nil, err } } else { - conn, err := dialContextExtra(ctx, dc.proxyAdapter, "udp", ip, port) + ipAddr, err := netip.ParseAddr(ip) + if err != nil { + return nil, err + } + + conn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", ipAddr, port) if err != nil { return nil, err } @@ -185,21 +361,158 @@ func (dc *quicClient) openConnection(ctx context.Context) (quic.Connection, erro udp = wrapConn } - session, err := quic.DialContext(ctx, udp, &udpAddr, host, tlsConfig, quicConfig) - if err != nil { - return nil, fmt.Errorf("failed to open QUIC connection: %w", err) - } - - dc.udp = udp - return session, nil -} - -func (dc *quicClient) openStream(ctx context.Context) (quic.Stream, error) { - session, err := dc.getConnection(ctx) + ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) + defer cancel() + host, _, err := net.SplitHostPort(doq.addr) if err != nil { return nil, err } - // open a new stream - return session.OpenStreamSync(ctx) + conn, err = quic.DialContext(ctx, udp, &udpAddr, host, tlsConfig, doq.getQUICConfig()) + if err != nil { + return nil, fmt.Errorf("opening quic connection to %s: %w", doq.addr, err) + } + + return conn, nil +} + +// closeConnWithError closes the active connection with error to make sure that +// new queries were processed in another connection. We can do that in the case +// of a fatal error. +func (p *dnsOverQUIC) closeConnWithError(err error) { + p.connMu.Lock() + defer p.connMu.Unlock() + + if p.conn == nil { + // Do nothing, there's no active conn anyways. + return + } + + code := QUICCodeNoError + if err != nil { + code = QUICCodeInternalError + } + + if errors.Is(err, quic.Err0RTTRejected) { + // Reset the TokenStore only if 0-RTT was rejected. + p.resetQUICConfig() + } + + err = p.conn.CloseWithError(code, "") + if err != nil { + log.Errorln("failed to close the conn: %v", err) + } + p.conn = nil +} + +// readMsg reads the incoming DNS message from the QUIC stream. +func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *D.Msg, err error) { + pool := p.getBytesPool() + bufPtr := pool.Get().(*[]byte) + + defer pool.Put(bufPtr) + + respBuf := *bufPtr + n, err := stream.Read(respBuf) + if err != nil && n == 0 { + return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err) + } + + // All DNS messages (queries and responses) sent over DoQ connections MUST + // be encoded as a 2-octet length field followed by the message content as + // specified in [RFC1035]. + // IMPORTANT: Note, that we ignore this prefix here as this implementation + // does not support receiving multiple messages over a single connection. + m = new(D.Msg) + err = m.Unpack(respBuf[2:]) + if err != nil { + return nil, fmt.Errorf("unpacking response from %s: %w", p.Address(), err) + } + + return m, nil +} + +// newQUICTokenStore creates a new quic.TokenStore that is necessary to have +// in order to benefit from 0-RTT. +func newQUICTokenStore() (s quic.TokenStore) { + // You can read more on address validation here: + // https://datatracker.ietf.org/doc/html/rfc9000#section-8.1 + // Setting maxOrigins to 1 and tokensPerOrigin to 10 assuming that this is + // more than enough for the way we use it (one connection per upstream). + return quic.NewLRUTokenStore(1, 10) +} + +// isQUICRetryError checks the error and determines whether it may signal that +// we should re-create the QUIC connection. This requirement is caused by +// quic-go issues, see the comments inside this function. +// TODO(ameshkov): re-test when updating quic-go. +func isQUICRetryError(err error) (ok bool) { + var qAppErr *quic.ApplicationError + if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 { + // This error is often returned when the server has been restarted, + // and we try to use the same connection on the client-side. It seems, + // that the old connections aren't closed immediately on the server-side + // and that's why one can run into this. + // In addition to that, quic-go HTTP3 client implementation does not + // clean up dead connections (this one is specific to DoH3 upstream): + // https://github.com/lucas-clemente/quic-go/issues/765 + return true + } + + var qIdleErr *quic.IdleTimeoutError + if errors.As(err, &qIdleErr) { + // This error means that the connection was closed due to being idle. + // In this case we should forcibly re-create the QUIC connection. + // Reproducing is rather simple, stop the server and wait for 30 seconds + // then try to send another request via the same upstream. + return true + } + + var resetErr *quic.StatelessResetError + if errors.As(err, &resetErr) { + // A stateless reset is sent when a server receives a QUIC packet that + // it doesn't know how to decrypt. For instance, it may happen when + // the server was recently rebooted. We should reconnect and try again + // in this case. + return true + } + + var qTransportError *quic.TransportError + if errors.As(err, &qTransportError) && qTransportError.ErrorCode == quic.NoError { + // A transport error with the NO_ERROR error code could be sent by the + // server when it considers that it's time to close the connection. + // For example, Google DNS eventually closes an active connection with + // the NO_ERROR code and "Connection max age expired" message: + // https://github.com/AdguardTeam/dnsproxy/issues/283 + return true + } + + if errors.Is(err, quic.Err0RTTRejected) { + // This error happens when we try to establish a 0-RTT connection with + // a token the server is no more aware of. This can be reproduced by + // restarting the QUIC server (it will clear its tokens cache). The + // next connection attempt will return this error until the client's + // tokens cache is purged. + return true + } + + return false +} + +func getDialHandler(r *Resolver, proxyAdapter string) dialHandler { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ip, err := r.ResolveIP(host) + if err != nil { + return nil, err + } + if len(proxyAdapter) == 0 { + return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port), dialer.WithDirect()) + } else { + return dialContextExtra(ctx, proxyAdapter, network, ip.Unmap(), port, dialer.WithDirect()) + } + } } diff --git a/dns/resolver.go b/dns/resolver.go index 84a38034..1184c2e7 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -355,6 +355,7 @@ type NameServer struct { Interface *atomic.String ProxyAdapter string Params map[string]string + PreferH3 bool } type FallbackFilter struct { diff --git a/dns/util.go b/dns/util.go index 50d9decd..17e4f5cf 100644 --- a/dns/util.go +++ b/dns/util.go @@ -19,6 +19,10 @@ import ( D "github.com/miekg/dns" ) +const ( + MaxMsgSize = 65535 +) + func putMsgToCache(c *cache.LruCache[string, *D.Msg], key string, msg *D.Msg) { var ttl uint32 switch { @@ -59,13 +63,17 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { for _, s := range servers { switch s.Net { case "https": - ret = append(ret, newDoHClient(s.Addr, resolver, s.Params, s.ProxyAdapter)) + ret = append(ret, newDoHClient(s.Addr, resolver, s.PreferH3, s.Params, s.ProxyAdapter)) continue case "dhcp": ret = append(ret, newDHCPClient(s.Addr)) continue case "quic": - ret = append(ret, newDOQ(resolver, s.Addr, s.ProxyAdapter)) + if doq, err := newDoQ(resolver, s.Addr, s.ProxyAdapter); err == nil { + ret = append(ret, doq) + }else{ + log.Fatalln("DoQ format error: %v",err) + } continue }