chore: wireguard using internal dialer

This commit is contained in:
wwqgtxx 2023-03-07 09:30:51 +08:00
parent 545a79d406
commit 9cc7fdaca9
3 changed files with 57 additions and 23 deletions

View file

@ -34,7 +34,7 @@ type WireGuard struct {
bind *wireguard.ClientBind
device *device.Device
tunDevice wireguard.Device
dialer *wgDialer
dialer *wgSingDialer
startOnce sync.Once
startErr error
}
@ -56,16 +56,28 @@ type WireGuardOption struct {
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
}
type wgDialer struct {
options []dialer.Option
type wgSingDialer struct {
dialer dialer.Dialer
}
func (d *wgDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return dialer.DialContext(ctx, network, destination.String(), d.options...)
var _ N.Dialer = &wgSingDialer{}
func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return d.dialer.DialContext(ctx, network, destination.String())
}
func (d *wgDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return dialer.ListenPacket(ctx, dialer.ParseNetwork("udp", destination.Addr), "", d.options...)
func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return d.dialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
}
type wgNetDialer struct {
tunDevice wireguard.Device
}
var _ dialer.NetDialer = &wgNetDialer{}
func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address))
}
func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
@ -79,7 +91,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
},
dialer: &wgDialer{},
dialer: &wgSingDialer{dialer: dialer.NewDialer()},
}
runtime.SetFinalizer(outbound, closeWireGuard)
@ -199,7 +211,8 @@ func closeWireGuard(w *WireGuard) {
}
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
w.dialer.options = opts
options := w.Base.DialOptions(opts...)
w.dialer.dialer = dialer.NewDialer(options...)
var conn net.Conn
w.startOnce.Do(func() {
w.startErr = w.tunDevice.Start()
@ -208,12 +221,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
return nil, w.startErr
}
if !metadata.Resolved() {
var addrs []netip.Addr
addrs, err = resolver.LookupIP(ctx, metadata.Host)
if err != nil {
return nil, err
}
conn, err = N.DialSerial(ctx, w.tunDevice, "tcp", M.ParseSocksaddr(metadata.RemoteAddress()), addrs)
options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
} else {
port, _ := strconv.Atoi(metadata.DstPort)
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, uint16(port)))
@ -228,7 +237,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts
}
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
w.dialer.options = opts
options := w.Base.DialOptions(opts...)
w.dialer.dialer = dialer.NewDialer(options...)
var pc net.PacketConn
w.startOnce.Do(func() {
w.startErr = w.tunDevice.Start()

View file

@ -109,7 +109,19 @@ func GetTcpConcurrent() bool {
}
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
dialer := &net.Dialer{}
address := net.JoinHostPort(destination.String(), port)
netDialer := opt.netDialer
switch netDialer.(type) {
case nil:
netDialer = &net.Dialer{}
case *net.Dialer:
netDialer = &*netDialer.(*net.Dialer) // make a copy
default:
return netDialer.DialContext(ctx, network, address)
}
dialer := netDialer.(*net.Dialer)
if opt.interfaceName != "" {
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
@ -118,8 +130,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination)
}
address := net.JoinHostPort(destination.String(), port)
if opt.tfo {
return dialTFO(ctx, *dialer, network, address)
}
@ -307,15 +317,15 @@ func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) {
}
type Dialer struct {
Opt option
opt option
}
func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return DialContext(ctx, network, address, WithOption(d.Opt))
return DialContext(ctx, network, address, WithOption(d.opt))
}
func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
opt := WithOption(d.Opt)
opt := WithOption(d.opt)
if rAddrPort.Addr().Unmap().IsLoopback() {
// avoid "The requested address is not valid in its context."
opt = WithInterface("")
@ -325,5 +335,5 @@ func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddr
func NewDialer(options ...Option) Dialer {
opt := applyOptions(options...)
return Dialer{Opt: *opt}
return Dialer{opt: *opt}
}

View file

@ -1,6 +1,9 @@
package dialer
import (
"context"
"net"
"github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic"
@ -12,6 +15,10 @@ var (
DefaultRoutingMark = atomic.NewInt32(0)
)
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type option struct {
interfaceName string
addrReuse bool
@ -20,6 +27,7 @@ type option struct {
prefer int
tfo bool
resolver resolver.Resolver
netDialer NetDialer
}
type Option func(opt *option)
@ -76,6 +84,12 @@ func WithTFO(tfo bool) Option {
}
}
func WithNetDialer(netDialer NetDialer) Option {
return func(opt *option) {
opt.netDialer = netDialer
}
}
func WithOption(o option) Option {
return func(opt *option) {
*opt = o