diff --git a/tunnel/connection.go b/tunnel/connection.go index 0384e805..cf024c75 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -3,26 +3,17 @@ package tunnel import ( "errors" "net" + "net/netip" "time" N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" - "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" ) func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error { defer packet.Drop() - // local resolve UDP dns - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) - if err != nil { - return err - } - metadata.DstIP = ip - } - addr := metadata.UDPAddr() if addr == nil { return errors.New("udp addr invalid") @@ -37,7 +28,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata return nil } -func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr net.Addr) { +func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr, fAddr netip.Addr) { buf := pool.Get(pool.UDPBufferSize) defer pool.Put(buf) defer natTable.Delete(key) @@ -50,11 +41,16 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr n return } - if fAddr != nil { - from = fAddr + fromUDPAddr := from.(*net.UDPAddr) + if fAddr.IsValid() { + fromAddr, _ := netip.AddrFromSlice(fromUDPAddr.IP) + fromAddr.Unmap() + if oAddr == fromAddr { + fromUDPAddr.IP = fAddr.AsSlice() + } } - _, err = packet.WriteBack(buf[:n], from) + _, err = packet.WriteBack(buf[:n], fromUDPAddr) if err != nil { return } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 6a39ce70..4a4685b2 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "runtime" "strconv" "sync" @@ -166,9 +167,10 @@ func handleUDPConn(packet *inbound.PacketAdapter) { } // make a fAddr if request ip is fakeip - var fAddr net.Addr + var fAddr netip.Addr if resolver.IsExistFakeIP(metadata.DstIP) { - fAddr = metadata.UDPAddr() + fAddr, _ = netip.AddrFromSlice(metadata.DstIP) + fAddr = fAddr.Unmap() } if err := preHandleMetadata(metadata); err != nil { @@ -176,6 +178,15 @@ func handleUDPConn(packet *inbound.PacketAdapter) { return } + // local resolve UDP dns + if !metadata.Resolved() { + ip, err := resolver.ResolveIP(metadata.Host) + if err != nil { + return + } + metadata.DstIP = ip + } + key := packet.LocalAddr().String() handle := func() bool { @@ -240,7 +251,9 @@ func handleUDPConn(packet *inbound.PacketAdapter) { log.Infoln("[UDP] %s --> %s doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.RemoteAddress()) } - go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr) + oAddr, _ := netip.AddrFromSlice(metadata.DstIP) + oAddr = oAddr.Unmap() + go handleUDPToLocal(packet.UDPPacket, pc, key, oAddr, fAddr) natTable.Set(key, pc) handle()