diff --git a/adapter/outbound/dns.go b/adapter/outbound/dns.go index 405392a1..21a5b2b7 100644 --- a/adapter/outbound/dns.go +++ b/adapter/outbound/dns.go @@ -2,10 +2,10 @@ package outbound import ( "context" - "fmt" "net" "time" + N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/resolver" @@ -24,7 +24,9 @@ type DnsOption struct { // DialContext implements C.ProxyAdapter func (d *Dns) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { - return nil, fmt.Errorf("dns outbound does not support tcp") + left, right := N.Pipe() + go resolver.RelayDnsConn(context.Background(), right, 0) + return NewConn(left, d), nil } // ListenPacketContext implements C.ProxyAdapter @@ -76,29 +78,44 @@ func (d *dnsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (d *dnsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + select { + case <-d.ctx.Done(): + return 0, net.ErrClosed + default: + } + + if len(p) > resolver.SafeDnsPacketSize { + // wtf??? + return len(p), nil + } + ctx, cancel := context.WithTimeout(d.ctx, resolver.DefaultDnsRelayTimeout) defer cancel() buf := pool.Get(resolver.SafeDnsPacketSize) put := func() { _ = pool.Put(buf) } - buf, err = resolver.RelayDnsPacket(ctx, p, buf) - if err != nil { - put() - return 0, err - } + copy(buf, p) // avoid p be changed after WriteTo returned - packet := dnsPacket{ - data: buf, - put: put, - addr: addr, - } - select { - case d.response <- packet: - return len(p), nil - case <-d.ctx.Done(): - put() - return 0, net.ErrClosed - } + go func() { // don't block the WriteTo function + buf, err = resolver.RelayDnsPacket(ctx, buf[:len(p)], buf) + if err != nil { + put() + return + } + + packet := dnsPacket{ + data: buf, + put: put, + addr: addr, + } + select { + case d.response <- packet: + break + case <-d.ctx.Done(): + put() + } + }() + return len(p), nil } func (d *dnsPacketConn) Close() error { diff --git a/common/net/deadline/conn.go b/common/net/deadline/conn.go index e8446ce2..fdf9334f 100644 --- a/common/net/deadline/conn.go +++ b/common/net/deadline/conn.go @@ -26,6 +26,11 @@ type Conn struct { resultCh chan *connReadResult } +func IsConn(conn any) bool { + _, ok := conn.(*Conn) + return ok +} + func NewConn(conn net.Conn) *Conn { c := &Conn{ ExtendedConn: bufio.NewExtendedConn(conn), diff --git a/common/net/deadline/pipe_sing.go b/common/net/deadline/pipe_sing.go index 20721fad..0f6d378d 100644 --- a/common/net/deadline/pipe_sing.go +++ b/common/net/deadline/pipe_sing.go @@ -215,3 +215,8 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) { return nil, os.ErrDeadlineExceeded } } + +func IsPipe(conn any) bool { + _, ok := conn.(*pipe) + return ok +} diff --git a/common/net/sing.go b/common/net/sing.go index 3296ad5b..d726f440 100644 --- a/common/net/sing.go +++ b/common/net/sing.go @@ -23,6 +23,12 @@ type ExtendedReader = network.ExtendedReader var WriteBuffer = bufio.WriteBuffer func NewDeadlineConn(conn net.Conn) ExtendedConn { + if deadline.IsPipe(conn) || deadline.IsPipe(network.UnwrapReader(conn)) { + return NewExtendedConn(conn) // pipe always have correctly deadline implement + } + if deadline.IsConn(conn) || deadline.IsConn(network.UnwrapReader(conn)) { + return NewExtendedConn(conn) // was a *deadline.Conn + } return deadline.NewConn(conn) } diff --git a/component/resolver/relay.go b/component/resolver/relay.go index 3bc54445..27b25af1 100644 --- a/component/resolver/relay.go +++ b/component/resolver/relay.go @@ -17,15 +17,15 @@ const DefaultDnsRelayTimeout = time.Second * 5 const SafeDnsPacketSize = 2 * 1024 // safe size which is 1232 from https://dnsflagday.net/2020/, so 2048 is enough -func RelayDnsConn(ctx context.Context, conn net.Conn) error { +func RelayDnsConn(ctx context.Context, conn net.Conn, readTimeout time.Duration) error { buff := pool.Get(pool.UDPBufferSize) defer func() { _ = pool.Put(buff) _ = conn.Close() }() for { - if conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) != nil { - break + if readTimeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(readTimeout)) } length := uint16(0) diff --git a/listener/sing_tun/dns.go b/listener/sing_tun/dns.go index 86237daa..42926732 100644 --- a/listener/sing_tun/dns.go +++ b/listener/sing_tun/dns.go @@ -37,7 +37,7 @@ func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort) bool { func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { if h.ShouldHijackDns(metadata.Destination.AddrPort()) { log.Debugln("[DNS] hijack tcp:%s", metadata.Destination.String()) - return resolver.RelayDnsConn(ctx, conn) + return resolver.RelayDnsConn(ctx, conn, resolver.DefaultDnsReadTimeout) } return h.ListenerHandler.NewConnection(ctx, conn, metadata) }