diff --git a/adapter/outbound/dns.go b/adapter/outbound/dns.go index 14eaf581..94819749 100644 --- a/adapter/outbound/dns.go +++ b/adapter/outbound/dns.go @@ -7,10 +7,12 @@ import ( "net/netip" "time" + "github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" + D "github.com/miekg/dns" ) @@ -32,52 +34,78 @@ func (d *Dns) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dia func (d *Dns) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { log.Debugln("[DNS] hijack udp:%s from %s", metadata.RemoteAddress(), metadata.SourceAddrPort()) + ctx, cancel := context.WithCancel(context.Background()) + return newPacketConn(&dnsPacketConn{ - response: make(chan []byte), - doneReading: make(chan int), + response: make(chan dnsPacket, 1), + ctx: ctx, + cancel: cancel, }, d), nil } +type dnsPacket struct { + data []byte + put func() + addr net.Addr +} + // dnsPacketConn implements net.PacketConn type dnsPacketConn struct { - response chan []byte - writeTo net.Addr - doneReading chan int + response chan dnsPacket + ctx context.Context + cancel context.CancelFunc +} + +func (d *dnsPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + select { + case packet := <-d.response: + return packet.data, packet.put, packet.addr, nil + case <-d.ctx.Done(): + return nil, nil, nil, net.ErrClosed + } } func (d *dnsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - buf := <-d.response - - log.Debugln("[DNS] hijack ReadFrom, len %d", len(buf)) - - if buf != nil { - n := copy(p, buf) - return n, d.writeTo, nil + select { + case packet := <-d.response: + n = copy(p, packet.data) + if packet.put != nil { + packet.put() + } + return n, packet.addr, nil + case <-d.ctx.Done(): + return 0, nil, net.ErrClosed } - - return 0, nil, fmt.Errorf("read from closed dns packet conn") } func (d *dnsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - log.Debugln("[DNS] hijack WriteTo %s, len %d", addr.String(), len(p)) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(d.ctx, time.Second*5) defer cancel() - buf, err := RelayDnsPacket(ctx, p, make([]byte, 4096)) + buf := pool.Get(2048) + put := func() { _ = pool.Put(buf) } + buf, err = RelayDnsPacket(ctx, p, buf) if err != nil { - log.Warnln("[DNS] dns hijack: relay dns packet: %s", err) + put() return 0, err } - d.writeTo = addr - d.response <- buf - - return len(p), nil + packet := dnsPacket{ + data: buf, + put: put, + addr: addr, + } + select { + case d.response <- packet: + return len(p), nil + case <-d.ctx.Done(): + put() + return 0, net.ErrClosed + } } func (d *dnsPacketConn) Close() error { - close(d.response) + d.cancel() return nil } @@ -101,7 +129,7 @@ func NewDnsWithOption(option DnsOption) *Dns { return &Dns{ Base: &Base{ name: option.Name, - tp: C.Direct, + tp: C.Dns, udp: true, tfo: option.TFO, mpTcp: option.MPTCP, @@ -130,14 +158,3 @@ func RelayDnsPacket(ctx context.Context, payload []byte, target []byte) ([]byte, r.Compress = true return r.PackBuffer(target) } - -func NewDns() *Dns { - return &Dns{ - Base: &Base{ - name: "DNS", - tp: C.Dns, - udp: true, - prefer: C.DualStack, - }, - } -}