From 17922dc85798b30d7e478f939cf60848f4f506bb Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Wed, 12 Apr 2023 11:09:31 +0800 Subject: [PATCH] chore: proxyDialer first using old function to let mux work --- component/proxydialer/proxydialer.go | 44 ++++++++-------------------- dns/util.go | 16 +++++----- 2 files changed, 20 insertions(+), 40 deletions(-) diff --git a/component/proxydialer/proxydialer.go b/component/proxydialer/proxydialer.go index 7fac628f..a32e54d1 100644 --- a/component/proxydialer/proxydialer.go +++ b/component/proxydialer/proxydialer.go @@ -45,23 +45,13 @@ func (p proxyDialer) DialContext(ctx context.Context, network, address string) ( return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil } var conn C.Conn - switch p.proxy.SupportWithDialer() { - case C.ALLNet: - fallthrough - case C.TCP: + if d, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work + conn, err = p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt)) + } else { conn, err = p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta) - if err != nil { - return nil, err - } - default: // fallback to old function - if d, ok := p.dialer.(dialer.Dialer); ok { // fallback to old function - conn, err = p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt)) - if err != nil { - return nil, err - } - } else { - return nil, C.ErrNotSupport - } + } + if err != nil { + return nil, err } if p.statistic { conn = statistic.NewTCPTracker(conn, statistic.DefaultManager, currentMeta, nil, 0, 0, false) @@ -81,23 +71,13 @@ func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata) var pc C.PacketConn var err error currentMeta.NetWork = C.UDP - switch p.proxy.SupportWithDialer() { - case C.ALLNet: - fallthrough - case C.UDP: + if d, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work + pc, err = p.proxy.ListenPacketContext(ctx, currentMeta, dialer.WithOption(d.Opt)) + } else { pc, err = p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta) - if err != nil { - return nil, err - } - default: // fallback to old function - if d, ok := p.dialer.(dialer.Dialer); ok { // fallback to old function - pc, err = p.proxy.ListenPacketContext(ctx, currentMeta, dialer.WithOption(d.Opt)) - if err != nil { - return nil, err - } - } else { - return nil, C.ErrNotSupport - } + } + if err != nil { + return nil, err } if p.statistic { pc = statistic.NewUDPTracker(pc, statistic.DefaultManager, currentMeta, nil, 0, 0, false) diff --git a/dns/util.go b/dns/util.go index 9df4482b..bfd2e9ed 100644 --- a/dns/util.go +++ b/dns/util.go @@ -170,15 +170,15 @@ func getDialHandler(r *Resolver, proxyAdapter C.ProxyAdapter, proxyName string, Host: host, DstPort: port, } - if proxyAdapter.IsL3Protocol(metadata) { - dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r) - if err != nil { - return nil, err - } - metadata.Host = "" - metadata.DstIP = dstIP - } if proxyAdapter != nil { + if proxyAdapter.IsL3Protocol(metadata) { // L3 proxy should resolve domain before to avoid loopback + dstIP, err := resolver.ResolveIPWithResolver(ctx, host, r) + if err != nil { + return nil, err + } + metadata.Host = "" + metadata.DstIP = dstIP + } return proxyAdapter.DialContext(ctx, metadata, opts...) } opts = append(opts, dialer.WithResolver(r))