From 9cc7fdaca9009d00abb41ee38f4eca0d70b6c45f Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Tue, 7 Mar 2023 09:30:51 +0800 Subject: [PATCH] chore: wireguard using internal dialer --- adapter/outbound/wireguard.go | 42 ++++++++++++++++++++++------------- component/dialer/dialer.go | 24 ++++++++++++++------ component/dialer/options.go | 14 ++++++++++++ 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index e3dafbbf..e5d7cf3f 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -34,7 +34,7 @@ type WireGuard struct { bind *wireguard.ClientBind device *device.Device tunDevice wireguard.Device - dialer *wgDialer + dialer *wgSingDialer startOnce sync.Once startErr error } @@ -56,16 +56,28 @@ type WireGuardOption struct { PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"` } -type wgDialer struct { - options []dialer.Option +type wgSingDialer struct { + dialer dialer.Dialer } -func (d *wgDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - return dialer.DialContext(ctx, network, destination.String(), d.options...) +var _ N.Dialer = &wgSingDialer{} + +func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + return d.dialer.DialContext(ctx, network, destination.String()) } -func (d *wgDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return dialer.ListenPacket(ctx, dialer.ParseNetwork("udp", destination.Addr), "", d.options...) +func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return d.dialer.ListenPacket(ctx, "udp", "", destination.AddrPort()) +} + +type wgNetDialer struct { + tunDevice wireguard.Device +} + +var _ dialer.NetDialer = &wgNetDialer{} + +func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address)) } func NewWireGuard(option WireGuardOption) (*WireGuard, error) { @@ -79,7 +91,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { rmark: option.RoutingMark, prefer: C.NewDNSPrefer(option.IPVersion), }, - dialer: &wgDialer{}, + dialer: &wgSingDialer{dialer: dialer.NewDialer()}, } runtime.SetFinalizer(outbound, closeWireGuard) @@ -199,7 +211,8 @@ func closeWireGuard(w *WireGuard) { } func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { - w.dialer.options = opts + options := w.Base.DialOptions(opts...) + w.dialer.dialer = dialer.NewDialer(options...) var conn net.Conn w.startOnce.Do(func() { w.startErr = w.tunDevice.Start() @@ -208,12 +221,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts return nil, w.startErr } if !metadata.Resolved() { - var addrs []netip.Addr - addrs, err = resolver.LookupIP(ctx, metadata.Host) - if err != nil { - return nil, err - } - conn, err = N.DialSerial(ctx, w.tunDevice, "tcp", M.ParseSocksaddr(metadata.RemoteAddress()), addrs) + options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice})) + conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress()) } else { port, _ := strconv.Atoi(metadata.DstPort) conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, uint16(port))) @@ -228,7 +237,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts } func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { - w.dialer.options = opts + options := w.Base.DialOptions(opts...) + w.dialer.dialer = dialer.NewDialer(options...) var pc net.PacketConn w.startOnce.Do(func() { w.startErr = w.tunDevice.Start() diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index f53435fb..025f7034 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -109,7 +109,19 @@ func GetTcpConcurrent() bool { } func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) { - dialer := &net.Dialer{} + address := net.JoinHostPort(destination.String(), port) + + netDialer := opt.netDialer + switch netDialer.(type) { + case nil: + netDialer = &net.Dialer{} + case *net.Dialer: + netDialer = &*netDialer.(*net.Dialer) // make a copy + default: + return netDialer.DialContext(ctx, network, address) + } + + dialer := netDialer.(*net.Dialer) if opt.interfaceName != "" { if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil { return nil, err @@ -118,8 +130,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po if opt.routingMark != 0 { bindMarkToDialer(opt.routingMark, dialer, network, destination) } - - address := net.JoinHostPort(destination.String(), port) if opt.tfo { return dialTFO(ctx, *dialer, network, address) } @@ -307,15 +317,15 @@ func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) { } type Dialer struct { - Opt option + opt option } func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return DialContext(ctx, network, address, WithOption(d.Opt)) + return DialContext(ctx, network, address, WithOption(d.opt)) } func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { - opt := WithOption(d.Opt) + opt := WithOption(d.opt) if rAddrPort.Addr().Unmap().IsLoopback() { // avoid "The requested address is not valid in its context." opt = WithInterface("") @@ -325,5 +335,5 @@ func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddr func NewDialer(options ...Option) Dialer { opt := applyOptions(options...) - return Dialer{Opt: *opt} + return Dialer{opt: *opt} } diff --git a/component/dialer/options.go b/component/dialer/options.go index 1c4e7bfc..372a2e63 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -1,6 +1,9 @@ package dialer import ( + "context" + "net" + "github.com/Dreamacro/clash/component/resolver" "go.uber.org/atomic" @@ -12,6 +15,10 @@ var ( DefaultRoutingMark = atomic.NewInt32(0) ) +type NetDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + type option struct { interfaceName string addrReuse bool @@ -20,6 +27,7 @@ type option struct { prefer int tfo bool resolver resolver.Resolver + netDialer NetDialer } type Option func(opt *option) @@ -76,6 +84,12 @@ func WithTFO(tfo bool) Option { } } +func WithNetDialer(netDialer NetDialer) Option { + return func(opt *option) { + opt.netDialer = netDialer + } +} + func WithOption(o option) Option { return func(opt *option) { *opt = o