From 68753b4ae11c7ca508bf422180e248b3ccd81068 Mon Sep 17 00:00:00 2001 From: Dreamacro <8615343+Dreamacro@users.noreply.github.com> Date: Fri, 15 Oct 2021 21:44:53 +0800 Subject: [PATCH] Chore: contexify ProxyAdapter ListenPacket --- .gitignore | 5 ++++- adapter/adapter.go | 18 +++++++++++++++--- adapter/outbound/base.go | 5 +++-- adapter/outbound/direct.go | 6 +++--- adapter/outbound/reject.go | 4 ++-- adapter/outbound/shadowsocks.go | 6 +++--- adapter/outbound/shadowsocksr.go | 6 +++--- adapter/outbound/socks5.go | 8 +++----- adapter/outbound/trojan.go | 6 ++---- adapter/outbound/vmess.go | 6 ++---- adapter/outboundgroup/fallback.go | 6 +++--- adapter/outboundgroup/loadbalance.go | 7 +++---- adapter/outboundgroup/selector.go | 6 +++--- adapter/outboundgroup/urltest.go | 6 +++--- constant/adapters.go | 19 +++++++++++++------ constant/provider/interface.go | 2 +- test/Makefile | 8 ++++++++ test/README.md | 4 ++-- test/clash_test.go | 7 ++++--- tunnel/tunnel.go | 27 ++++++++++++++++----------- 20 files changed, 96 insertions(+), 66 deletions(-) create mode 100644 test/Makefile diff --git a/.gitignore b/.gitignore index 0593cfd0..52efcc9b 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,7 @@ bin/* # Output of the go coverage tool, specifically when used with LiteIDE *.out -# dep +# go mod vendor vendor # GoLand @@ -20,3 +20,6 @@ vendor # macOS file .DS_Store + +# test suite +test/config/cache* diff --git a/adapter/adapter.go b/adapter/adapter.go index 526866a5..26330163 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -36,12 +36,24 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { // DialContext implements C.ProxyAdapter func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { conn, err := p.ProxyAdapter.DialContext(ctx, metadata) - if err != nil { - p.alive.Store(false) - } + p.alive.Store(err == nil) return conn, err } +// DialUDP implements C.ProxyAdapter +func (p *Proxy) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout) + defer cancel() + return p.ListenPacketContext(ctx, metadata) +} + +// ListenPacketContext implements C.ProxyAdapter +func (p *Proxy) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { + pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata) + p.alive.Store(err == nil) + return pc, err +} + // DelayHistory implements C.Proxy func (p *Proxy) DelayHistory() []C.DelayHistory { queue := p.history.Copy() diff --git a/adapter/outbound/base.go b/adapter/outbound/base.go index a001d799..22c0142b 100644 --- a/adapter/outbound/base.go +++ b/adapter/outbound/base.go @@ -1,6 +1,7 @@ package outbound import ( + "context" "encoding/json" "errors" "net" @@ -30,8 +31,8 @@ func (b *Base) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { return c, errors.New("no support") } -// DialUDP implements C.ProxyAdapter -func (b *Base) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { +// ListenPacketContext implements C.ProxyAdapter +func (b *Base) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { return nil, errors.New("no support") } diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index 4b53e306..42433c41 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -22,9 +22,9 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, return NewConn(c, d), nil } -// DialUDP implements C.ProxyAdapter -func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { - pc, err := dialer.ListenPacket(context.Background(), "udp", "") +// ListenPacketContext implements C.ProxyAdapter +func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { + pc, err := dialer.ListenPacket(ctx, "udp", "") if err != nil { return nil, err } diff --git a/adapter/outbound/reject.go b/adapter/outbound/reject.go index 36750496..a97c6c71 100644 --- a/adapter/outbound/reject.go +++ b/adapter/outbound/reject.go @@ -19,8 +19,8 @@ func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, return NewConn(&NopConn{}, r), nil } -// DialUDP implements C.ProxyAdapter -func (r *Reject) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { +// ListenPacketContext implements C.ProxyAdapter +func (r *Reject) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { return nil, errors.New("match reject rule") } diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index 64b431a9..0194954e 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -87,9 +87,9 @@ func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (_ return NewConn(c, ss), err } -// DialUDP implements C.ProxyAdapter -func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { - pc, err := dialer.ListenPacket(context.Background(), "udp", "") +// ListenPacketContext implements C.ProxyAdapter +func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { + pc, err := dialer.ListenPacket(ctx, "udp", "") if err != nil { return nil, err } diff --git a/adapter/outbound/shadowsocksr.go b/adapter/outbound/shadowsocksr.go index 9ba8bc22..fb0dd7a5 100644 --- a/adapter/outbound/shadowsocksr.go +++ b/adapter/outbound/shadowsocksr.go @@ -72,9 +72,9 @@ func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata) return NewConn(c, ssr), err } -// DialUDP implements C.ProxyAdapter -func (ssr *ShadowSocksR) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { - pc, err := dialer.ListenPacket(context.Background(), "udp", "") +// ListenPacketContext implements C.ProxyAdapter +func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { + pc, err := dialer.ListenPacket(ctx, "udp", "") if err != nil { return nil, err } diff --git a/adapter/outbound/socks5.go b/adapter/outbound/socks5.go index d9739557..7714c5c5 100644 --- a/adapter/outbound/socks5.go +++ b/adapter/outbound/socks5.go @@ -76,10 +76,8 @@ func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Co return NewConn(c, ss), nil } -// DialUDP implements C.ProxyAdapter -func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) - defer cancel() +// ListenPacketContext implements C.ProxyAdapter +func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { c, err := dialer.DialContext(ctx, "tcp", ss.addr) if err != nil { err = fmt.Errorf("%s connect error: %w", ss.addr, err) @@ -109,7 +107,7 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { return } - pc, err := dialer.ListenPacket(context.Background(), "udp", "") + pc, err := dialer.ListenPacket(ctx, "udp", "") if err != nil { return } diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index afed410f..b8ea49e6 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -88,8 +88,8 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con return NewConn(c, t), err } -// DialUDP implements C.ProxyAdapter -func (t *Trojan) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { +// ListenPacketContext implements C.ProxyAdapter +func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { var c net.Conn // grpc transport @@ -100,8 +100,6 @@ func (t *Trojan) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { } defer safeConnClose(c, err) } else { - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) - defer cancel() c, err = dialer.DialContext(ctx, "tcp", t.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %w", t.addr, err) diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 0dbb0f1b..209084ca 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -215,8 +215,8 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn return NewConn(c, v), err } -// DialUDP implements C.ProxyAdapter -func (v *Vmess) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { +// ListenPacketContext implements C.ProxyAdapter +func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { ip, err := resolver.ResolveIP(metadata.Host) @@ -237,8 +237,6 @@ func (v *Vmess) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { c, err = v.client.StreamConn(c, parseVmessAddr(metadata)) } else { - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) - defer cancel() c, err = dialer.DialContext(ctx, "tcp", v.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) diff --git a/adapter/outboundgroup/fallback.go b/adapter/outboundgroup/fallback.go index da01d4de..3221b552 100644 --- a/adapter/outboundgroup/fallback.go +++ b/adapter/outboundgroup/fallback.go @@ -32,10 +32,10 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con return c, err } -// DialUDP implements C.ProxyAdapter -func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { +// ListenPacketContext implements C.ProxyAdapter +func (f *Fallback) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { proxy := f.findAliveProxy(true) - pc, err := proxy.DialUDP(metadata) + pc, err := proxy.ListenPacketContext(ctx, metadata) if err == nil { pc.AppendToChains(f) } diff --git a/adapter/outboundgroup/loadbalance.go b/adapter/outboundgroup/loadbalance.go index e529b53b..fb284010 100644 --- a/adapter/outboundgroup/loadbalance.go +++ b/adapter/outboundgroup/loadbalance.go @@ -82,8 +82,8 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c return } -// DialUDP implements C.ProxyAdapter -func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, err error) { +// ListenPacketContext implements C.ProxyAdapter +func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (pc C.PacketConn, err error) { defer func() { if err == nil { pc.AppendToChains(lb) @@ -91,8 +91,7 @@ func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, err error }() proxy := lb.Unwrap(metadata) - - return proxy.DialUDP(metadata) + return proxy.ListenPacketContext(ctx, metadata) } // SupportUDP implements C.ProxyAdapter diff --git a/adapter/outboundgroup/selector.go b/adapter/outboundgroup/selector.go index 12453927..008e8af8 100644 --- a/adapter/outboundgroup/selector.go +++ b/adapter/outboundgroup/selector.go @@ -28,9 +28,9 @@ func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con return c, err } -// DialUDP implements C.ProxyAdapter -func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { - pc, err := s.selectedProxy(true).DialUDP(metadata) +// ListenPacketContext implements C.ProxyAdapter +func (s *Selector) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { + pc, err := s.selectedProxy(true).ListenPacketContext(ctx, metadata) if err == nil { pc.AppendToChains(s) } diff --git a/adapter/outboundgroup/urltest.go b/adapter/outboundgroup/urltest.go index 2a11cd23..b27f12a4 100644 --- a/adapter/outboundgroup/urltest.go +++ b/adapter/outboundgroup/urltest.go @@ -42,9 +42,9 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Co return c, err } -// DialUDP implements C.ProxyAdapter -func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { - pc, err := u.fast(true).DialUDP(metadata) +// ListenPacketContext implements C.ProxyAdapter +func (u *URLTest) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { + pc, err := u.fast(true).ListenPacketContext(ctx, metadata) if err == nil { pc.AppendToChains(u) } diff --git a/constant/adapters.go b/constant/adapters.go index 733d86b9..2f7006e9 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -29,6 +29,7 @@ const ( const ( DefaultTCPTimeout = 5 * time.Second + DefaultUDPTimeout = DefaultTCPTimeout ) type Connection interface { @@ -73,11 +74,14 @@ type PacketConn interface { type ProxyAdapter interface { Name() string Type() AdapterType + Addr() string + SupportUDP() bool + MarshalJSON() ([]byte, error) // StreamConn wraps a protocol around net.Conn with Metadata. // // Examples: - // conn, _ := net.Dial("tcp", "host:port") + // conn, _ := net.DialContext(context.Background(), "tcp", "host:port") // conn, _ = adapter.StreamConn(conn, metadata) // // It returns a C.Conn with protocol which start with @@ -88,10 +92,8 @@ type ProxyAdapter interface { // contains multiplexing-related reuse logic (if any) DialContext(ctx context.Context, metadata *Metadata) (Conn, error) - DialUDP(metadata *Metadata) (PacketConn, error) - SupportUDP() bool - MarshalJSON() ([]byte, error) - Addr() string + ListenPacketContext(ctx context.Context, metadata *Metadata) (PacketConn, error) + // Unwrap extracts the proxy from a proxy-group. It returns nil when nothing to extract. Unwrap(metadata *Metadata) Proxy } @@ -105,9 +107,14 @@ type Proxy interface { ProxyAdapter Alive() bool DelayHistory() []DelayHistory - Dial(metadata *Metadata) (Conn, error) LastDelay() uint16 URLTest(ctx context.Context, url string) (uint16, error) + + // Deprecated: use DialContext instead. + Dial(metadata *Metadata) (Conn, error) + + // Deprecated: use DialPacketConn instead. + DialUDP(metadata *Metadata) (PacketConn, error) } // AdapterType is enum of adapter type diff --git a/constant/provider/interface.go b/constant/provider/interface.go index d3483cdc..53bda7ea 100644 --- a/constant/provider/interface.go +++ b/constant/provider/interface.go @@ -67,7 +67,7 @@ type ProxyProvider interface { Provider Proxies() []constant.Proxy // ProxiesWithTouch is used to inform the provider that the proxy is actually being used while getting the list of proxies. - // Commonly used in Dial and DialUDP + // Commonly used in DialContext and DialPacketConn ProxiesWithTouch() []constant.Proxy HealthCheck() } diff --git a/test/Makefile b/test/Makefile new file mode 100644 index 00000000..012d88d5 --- /dev/null +++ b/test/Makefile @@ -0,0 +1,8 @@ +lint: + golangci-lint run --disable-all -E govet -E gofumpt -E megacheck ./... + +test: + go test -p 1 -v ./... + +benchmark: + go test -benchmem -run=^$ -bench . diff --git a/test/README.md b/test/README.md index 823fc544..a95f3aea 100644 --- a/test/README.md +++ b/test/README.md @@ -45,7 +45,7 @@ Prerequisite * docker (support Linux and macOS) ``` -$ go test -p 1 -v +$ make test ``` benchmark (Linux) @@ -55,5 +55,5 @@ benchmark (Linux) > (change chunkSize to measure the maximum throughput of clash on your machine) ``` -$ go test -benchmem -run=^$ -bench . +$ make benchmark ``` diff --git a/test/clash_test.go b/test/clash_test.go index 4011b04a..a3303a23 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -96,6 +96,7 @@ func init() { images := []string{ ImageShadowsocks, + ImageShadowsocksRust, ImageVmess, ImageTrojan, ImageSnell, @@ -582,7 +583,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) { return } - pc, err := proxy.DialUDP(&C.Metadata{ + pc, err := proxy.ListenPacketContext(context.Background(), &C.Metadata{ NetWork: C.UDP, DstIP: localIP, DstPort: "10001", @@ -595,7 +596,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) { assert.NoError(t, testPingPongWithPacketConn(t, pc)) - pc, err = proxy.DialUDP(&C.Metadata{ + pc, err = proxy.ListenPacketContext(context.Background(), &C.Metadata{ NetWork: C.UDP, DstIP: localIP, DstPort: "10001", @@ -608,7 +609,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) { assert.NoError(t, testLargeDataWithPacketConn(t, pc)) - pc, err = proxy.DialUDP(&C.Metadata{ + pc, err = proxy.ListenPacketContext(context.Background(), &C.Metadata{ NetWork: C.UDP, DstIP: localIP, DstPort: "10001", diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 2a35f096..ccd6a58b 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -1,6 +1,7 @@ package tunnel import ( + "context" "fmt" "net" "runtime" @@ -12,7 +13,7 @@ import ( "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/constant/provider" - "github.com/Dreamacro/clash/context" + icontext "github.com/Dreamacro/clash/context" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/tunnel/statistic" ) @@ -209,14 +210,16 @@ func handleUDPConn(packet *inbound.PacketAdapter) { cond.Broadcast() }() - ctx := context.NewPacketConnContext(metadata) - proxy, rule, err := resolveMetadata(ctx, metadata) + pCtx := icontext.NewPacketConnContext(metadata) + proxy, rule, err := resolveMetadata(pCtx, metadata) if err != nil { log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) return } - rawPc, err := proxy.DialUDP(metadata) + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout) + defer cancel() + rawPc, err := proxy.ListenPacketContext(ctx, metadata) if err != nil { if rule == nil { log.Warnln("[UDP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error()) @@ -225,7 +228,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) { } return } - ctx.InjectPacketConn(rawPc) + pCtx.InjectPacketConn(rawPc) pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule) switch true { @@ -246,10 +249,10 @@ func handleUDPConn(packet *inbound.PacketAdapter) { }() } -func handleTCPConn(ctx C.ConnContext) { - defer ctx.Conn().Close() +func handleTCPConn(connCtx C.ConnContext) { + defer connCtx.Conn().Close() - metadata := ctx.Metadata() + metadata := connCtx.Metadata() if !metadata.Valid() { log.Warnln("[Metadata] not valid: %#v", metadata) return @@ -260,13 +263,15 @@ func handleTCPConn(ctx C.ConnContext) { return } - proxy, rule, err := resolveMetadata(ctx, metadata) + proxy, rule, err := resolveMetadata(connCtx, metadata) if err != nil { log.Warnln("[Metadata] parse failed: %s", err.Error()) return } - remoteConn, err := proxy.Dial(metadata) + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) + defer cancel() + remoteConn, err := proxy.DialContext(ctx, metadata) if err != nil { if rule == nil { log.Warnln("[TCP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error()) @@ -289,7 +294,7 @@ func handleTCPConn(ctx C.ConnContext) { log.Infoln("[TCP] %s --> %s doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.RemoteAddress()) } - handleSocket(ctx, remoteConn) + handleSocket(connCtx, remoteConn) } func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool {