Chore: contexify ProxyAdapter ListenPacket

This commit is contained in:
Dreamacro 2021-10-15 21:44:53 +08:00
parent 583b2a5ace
commit 68753b4ae1
20 changed files with 96 additions and 66 deletions

5
.gitignore vendored
View file

@ -12,7 +12,7 @@ bin/*
# Output of the go coverage tool, specifically when used with LiteIDE # Output of the go coverage tool, specifically when used with LiteIDE
*.out *.out
# dep # go mod vendor
vendor vendor
# GoLand # GoLand
@ -20,3 +20,6 @@ vendor
# macOS file # macOS file
.DS_Store .DS_Store
# test suite
test/config/cache*

View file

@ -36,12 +36,24 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.DialContext(ctx, metadata) conn, err := p.ProxyAdapter.DialContext(ctx, metadata)
if err != nil { p.alive.Store(err == nil)
p.alive.Store(false)
}
return conn, err return conn, err
} }
// DialUDP implements C.ProxyAdapter
func (p *Proxy) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
defer cancel()
return p.ListenPacketContext(ctx, metadata)
}
// ListenPacketContext implements C.ProxyAdapter
func (p *Proxy) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata)
p.alive.Store(err == nil)
return pc, err
}
// DelayHistory implements C.Proxy // DelayHistory implements C.Proxy
func (p *Proxy) DelayHistory() []C.DelayHistory { func (p *Proxy) DelayHistory() []C.DelayHistory {
queue := p.history.Copy() queue := p.history.Copy()

View file

@ -1,6 +1,7 @@
package outbound package outbound
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net" "net"
@ -30,8 +31,8 @@ func (b *Base) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
return c, errors.New("no support") return c, errors.New("no support")
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (b *Base) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (b *Base) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return nil, errors.New("no support") return nil, errors.New("no support")
} }

View file

@ -22,9 +22,9 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
return NewConn(c, d), nil return NewConn(c, d), nil
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
pc, err := dialer.ListenPacket(context.Background(), "udp", "") pc, err := dialer.ListenPacket(ctx, "udp", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -19,8 +19,8 @@ func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
return NewConn(&NopConn{}, r), nil return NewConn(&NopConn{}, r), nil
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (r *Reject) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (r *Reject) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return nil, errors.New("match reject rule") return nil, errors.New("match reject rule")
} }

View file

@ -87,9 +87,9 @@ func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (_
return NewConn(c, ss), err return NewConn(c, ss), err
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
pc, err := dialer.ListenPacket(context.Background(), "udp", "") pc, err := dialer.ListenPacket(ctx, "udp", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -72,9 +72,9 @@ func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata)
return NewConn(c, ssr), err return NewConn(c, ssr), err
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
pc, err := dialer.ListenPacket(context.Background(), "udp", "") pc, err := dialer.ListenPacket(ctx, "udp", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -76,10 +76,8 @@ func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Co
return NewConn(c, ss), nil return NewConn(c, ss), nil
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
c, err := dialer.DialContext(ctx, "tcp", ss.addr) c, err := dialer.DialContext(ctx, "tcp", ss.addr)
if err != nil { if err != nil {
err = fmt.Errorf("%s connect error: %w", ss.addr, err) err = fmt.Errorf("%s connect error: %w", ss.addr, err)
@ -109,7 +107,7 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) {
return return
} }
pc, err := dialer.ListenPacket(context.Background(), "udp", "") pc, err := dialer.ListenPacket(ctx, "udp", "")
if err != nil { if err != nil {
return return
} }

View file

@ -88,8 +88,8 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
return NewConn(c, t), err return NewConn(c, t), err
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (t *Trojan) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
var c net.Conn var c net.Conn
// grpc transport // grpc transport
@ -100,8 +100,6 @@ func (t *Trojan) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) {
} }
defer safeConnClose(c, err) defer safeConnClose(c, err)
} else { } else {
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
c, err = dialer.DialContext(ctx, "tcp", t.addr) c, err = dialer.DialContext(ctx, "tcp", t.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)

View file

@ -215,8 +215,8 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
return NewConn(c, v), err return NewConn(c, v), err
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (v *Vmess) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
// vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host) ip, err := resolver.ResolveIP(metadata.Host)
@ -237,8 +237,6 @@ func (v *Vmess) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) {
c, err = v.client.StreamConn(c, parseVmessAddr(metadata)) c, err = v.client.StreamConn(c, parseVmessAddr(metadata))
} else { } else {
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
c, err = dialer.DialContext(ctx, "tcp", v.addr) c, err = dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())

View file

@ -32,10 +32,10 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
return c, err return c, err
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (f *Fallback) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
proxy := f.findAliveProxy(true) proxy := f.findAliveProxy(true)
pc, err := proxy.DialUDP(metadata) pc, err := proxy.ListenPacketContext(ctx, metadata)
if err == nil { if err == nil {
pc.AppendToChains(f) pc.AppendToChains(f)
} }

View file

@ -82,8 +82,8 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c
return return
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, err error) { func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (pc C.PacketConn, err error) {
defer func() { defer func() {
if err == nil { if err == nil {
pc.AppendToChains(lb) pc.AppendToChains(lb)
@ -91,8 +91,7 @@ func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, err error
}() }()
proxy := lb.Unwrap(metadata) proxy := lb.Unwrap(metadata)
return proxy.ListenPacketContext(ctx, metadata)
return proxy.DialUDP(metadata)
} }
// SupportUDP implements C.ProxyAdapter // SupportUDP implements C.ProxyAdapter

View file

@ -28,9 +28,9 @@ func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
return c, err return c, err
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (s *Selector) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
pc, err := s.selectedProxy(true).DialUDP(metadata) pc, err := s.selectedProxy(true).ListenPacketContext(ctx, metadata)
if err == nil { if err == nil {
pc.AppendToChains(s) pc.AppendToChains(s)
} }

View file

@ -42,9 +42,9 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Co
return c, err return c, err
} }
// DialUDP implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (u *URLTest) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
pc, err := u.fast(true).DialUDP(metadata) pc, err := u.fast(true).ListenPacketContext(ctx, metadata)
if err == nil { if err == nil {
pc.AppendToChains(u) pc.AppendToChains(u)
} }

View file

@ -29,6 +29,7 @@ const (
const ( const (
DefaultTCPTimeout = 5 * time.Second DefaultTCPTimeout = 5 * time.Second
DefaultUDPTimeout = DefaultTCPTimeout
) )
type Connection interface { type Connection interface {
@ -73,11 +74,14 @@ type PacketConn interface {
type ProxyAdapter interface { type ProxyAdapter interface {
Name() string Name() string
Type() AdapterType Type() AdapterType
Addr() string
SupportUDP() bool
MarshalJSON() ([]byte, error)
// StreamConn wraps a protocol around net.Conn with Metadata. // StreamConn wraps a protocol around net.Conn with Metadata.
// //
// Examples: // Examples:
// conn, _ := net.Dial("tcp", "host:port") // conn, _ := net.DialContext(context.Background(), "tcp", "host:port")
// conn, _ = adapter.StreamConn(conn, metadata) // conn, _ = adapter.StreamConn(conn, metadata)
// //
// It returns a C.Conn with protocol which start with // It returns a C.Conn with protocol which start with
@ -88,10 +92,8 @@ type ProxyAdapter interface {
// contains multiplexing-related reuse logic (if any) // contains multiplexing-related reuse logic (if any)
DialContext(ctx context.Context, metadata *Metadata) (Conn, error) DialContext(ctx context.Context, metadata *Metadata) (Conn, error)
DialUDP(metadata *Metadata) (PacketConn, error) ListenPacketContext(ctx context.Context, metadata *Metadata) (PacketConn, error)
SupportUDP() bool
MarshalJSON() ([]byte, error)
Addr() string
// Unwrap extracts the proxy from a proxy-group. It returns nil when nothing to extract. // Unwrap extracts the proxy from a proxy-group. It returns nil when nothing to extract.
Unwrap(metadata *Metadata) Proxy Unwrap(metadata *Metadata) Proxy
} }
@ -105,9 +107,14 @@ type Proxy interface {
ProxyAdapter ProxyAdapter
Alive() bool Alive() bool
DelayHistory() []DelayHistory DelayHistory() []DelayHistory
Dial(metadata *Metadata) (Conn, error)
LastDelay() uint16 LastDelay() uint16
URLTest(ctx context.Context, url string) (uint16, error) URLTest(ctx context.Context, url string) (uint16, error)
// Deprecated: use DialContext instead.
Dial(metadata *Metadata) (Conn, error)
// Deprecated: use DialPacketConn instead.
DialUDP(metadata *Metadata) (PacketConn, error)
} }
// AdapterType is enum of adapter type // AdapterType is enum of adapter type

View file

@ -67,7 +67,7 @@ type ProxyProvider interface {
Provider Provider
Proxies() []constant.Proxy Proxies() []constant.Proxy
// ProxiesWithTouch is used to inform the provider that the proxy is actually being used while getting the list of proxies. // ProxiesWithTouch is used to inform the provider that the proxy is actually being used while getting the list of proxies.
// Commonly used in Dial and DialUDP // Commonly used in DialContext and DialPacketConn
ProxiesWithTouch() []constant.Proxy ProxiesWithTouch() []constant.Proxy
HealthCheck() HealthCheck()
} }

8
test/Makefile Normal file
View file

@ -0,0 +1,8 @@
lint:
golangci-lint run --disable-all -E govet -E gofumpt -E megacheck ./...
test:
go test -p 1 -v ./...
benchmark:
go test -benchmem -run=^$ -bench .

View file

@ -45,7 +45,7 @@ Prerequisite
* docker (support Linux and macOS) * docker (support Linux and macOS)
``` ```
$ go test -p 1 -v $ make test
``` ```
benchmark (Linux) benchmark (Linux)
@ -55,5 +55,5 @@ benchmark (Linux)
> (change chunkSize to measure the maximum throughput of clash on your machine) > (change chunkSize to measure the maximum throughput of clash on your machine)
``` ```
$ go test -benchmem -run=^$ -bench . $ make benchmark
``` ```

View file

@ -96,6 +96,7 @@ func init() {
images := []string{ images := []string{
ImageShadowsocks, ImageShadowsocks,
ImageShadowsocksRust,
ImageVmess, ImageVmess,
ImageTrojan, ImageTrojan,
ImageSnell, ImageSnell,
@ -582,7 +583,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) {
return return
} }
pc, err := proxy.DialUDP(&C.Metadata{ pc, err := proxy.ListenPacketContext(context.Background(), &C.Metadata{
NetWork: C.UDP, NetWork: C.UDP,
DstIP: localIP, DstIP: localIP,
DstPort: "10001", DstPort: "10001",
@ -595,7 +596,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) {
assert.NoError(t, testPingPongWithPacketConn(t, pc)) assert.NoError(t, testPingPongWithPacketConn(t, pc))
pc, err = proxy.DialUDP(&C.Metadata{ pc, err = proxy.ListenPacketContext(context.Background(), &C.Metadata{
NetWork: C.UDP, NetWork: C.UDP,
DstIP: localIP, DstIP: localIP,
DstPort: "10001", DstPort: "10001",
@ -608,7 +609,7 @@ func testSuit(t *testing.T, proxy C.ProxyAdapter) {
assert.NoError(t, testLargeDataWithPacketConn(t, pc)) assert.NoError(t, testLargeDataWithPacketConn(t, pc))
pc, err = proxy.DialUDP(&C.Metadata{ pc, err = proxy.ListenPacketContext(context.Background(), &C.Metadata{
NetWork: C.UDP, NetWork: C.UDP,
DstIP: localIP, DstIP: localIP,
DstPort: "10001", DstPort: "10001",

View file

@ -1,6 +1,7 @@
package tunnel package tunnel
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"runtime" "runtime"
@ -12,7 +13,7 @@ import (
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/provider" "github.com/Dreamacro/clash/constant/provider"
"github.com/Dreamacro/clash/context" icontext "github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/tunnel/statistic" "github.com/Dreamacro/clash/tunnel/statistic"
) )
@ -209,14 +210,16 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
cond.Broadcast() cond.Broadcast()
}() }()
ctx := context.NewPacketConnContext(metadata) pCtx := icontext.NewPacketConnContext(metadata)
proxy, rule, err := resolveMetadata(ctx, metadata) proxy, rule, err := resolveMetadata(pCtx, metadata)
if err != nil { if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return return
} }
rawPc, err := proxy.DialUDP(metadata) ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
defer cancel()
rawPc, err := proxy.ListenPacketContext(ctx, metadata)
if err != nil { if err != nil {
if rule == nil { if rule == nil {
log.Warnln("[UDP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error()) log.Warnln("[UDP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error())
@ -225,7 +228,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
} }
return return
} }
ctx.InjectPacketConn(rawPc) pCtx.InjectPacketConn(rawPc)
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule) pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule)
switch true { switch true {
@ -246,10 +249,10 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
}() }()
} }
func handleTCPConn(ctx C.ConnContext) { func handleTCPConn(connCtx C.ConnContext) {
defer ctx.Conn().Close() defer connCtx.Conn().Close()
metadata := ctx.Metadata() metadata := connCtx.Metadata()
if !metadata.Valid() { if !metadata.Valid() {
log.Warnln("[Metadata] not valid: %#v", metadata) log.Warnln("[Metadata] not valid: %#v", metadata)
return return
@ -260,13 +263,15 @@ func handleTCPConn(ctx C.ConnContext) {
return return
} }
proxy, rule, err := resolveMetadata(ctx, metadata) proxy, rule, err := resolveMetadata(connCtx, metadata)
if err != nil { if err != nil {
log.Warnln("[Metadata] parse failed: %s", err.Error()) log.Warnln("[Metadata] parse failed: %s", err.Error())
return return
} }
remoteConn, err := proxy.Dial(metadata) ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
remoteConn, err := proxy.DialContext(ctx, metadata)
if err != nil { if err != nil {
if rule == nil { if rule == nil {
log.Warnln("[TCP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error()) log.Warnln("[TCP] dial %s to %s error: %s", proxy.Name(), metadata.RemoteAddress(), err.Error())
@ -289,7 +294,7 @@ func handleTCPConn(ctx C.ConnContext) {
log.Infoln("[TCP] %s --> %s doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.RemoteAddress()) log.Infoln("[TCP] %s --> %s doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.RemoteAddress())
} }
handleSocket(ctx, remoteConn) handleSocket(connCtx, remoteConn)
} }
func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool { func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool {