diff --git a/common/net/packet.go b/common/net/packet.go index 865590ce..9afe86b0 100644 --- a/common/net/packet.go +++ b/common/net/packet.go @@ -1,9 +1,6 @@ package net import ( - "net" - "sync" - "github.com/Dreamacro/clash/common/net/deadline" "github.com/Dreamacro/clash/common/net/packet" ) @@ -11,30 +8,10 @@ import ( type EnhancePacketConn = packet.EnhancePacketConn var NewEnhancePacketConn = packet.NewEnhancePacketConn +var NewThreadSafePacketConn = packet.NewThreadSafePacketConn +var NewRefPacketConn = packet.NewRefPacketConn + var NewDeadlineNetPacketConn = deadline.NewNetPacketConn var NewDeadlineEnhancePacketConn = deadline.NewEnhancePacketConn var NewDeadlineSingPacketConn = deadline.NewSingPacketConn var NewDeadlineEnhanceSingPacketConn = deadline.NewEnhanceSingPacketConn - -type threadSafePacketConn struct { - EnhancePacketConn - access sync.Mutex -} - -func (c *threadSafePacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - c.access.Lock() - defer c.access.Unlock() - return c.EnhancePacketConn.WriteTo(b, addr) -} - -func (c *threadSafePacketConn) Upstream() any { - return c.EnhancePacketConn -} - -func (c *threadSafePacketConn) ReaderReplaceable() bool { - return true -} - -func NewThreadSafePacketConn(pc net.PacketConn) net.PacketConn { - return &threadSafePacketConn{EnhancePacketConn: NewEnhancePacketConn(pc)} -} diff --git a/common/net/packet/packet_sing.go b/common/net/packet/packet_sing.go index daa352c8..cfcf5ed0 100644 --- a/common/net/packet/packet_sing.go +++ b/common/net/packet/packet_sing.go @@ -12,13 +12,13 @@ import ( type SingPacketConn = N.NetPacketConn type EnhanceSingPacketConn interface { - N.NetPacketConn + SingPacketConn EnhancePacketConn } type enhanceSingPacketConn struct { - N.NetPacketConn - readWaiter N.PacketReadWaiter + SingPacketConn + packetReadWaiter N.PacketReadWaiter } func (c *enhanceSingPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { @@ -28,12 +28,12 @@ func (c *enhanceSingPacketConn) WaitReadFrom() (data []byte, put func(), addr ne buff = buf.NewPacket() // do not use stack buffer return buff } - if c.readWaiter != nil { - c.readWaiter.InitializeReadWaiter(newBuffer) - defer c.readWaiter.InitializeReadWaiter(nil) - dest, err = c.readWaiter.WaitReadPacket() + if c.packetReadWaiter != nil { + c.packetReadWaiter.InitializeReadWaiter(newBuffer) + defer c.packetReadWaiter.InitializeReadWaiter(nil) + dest, err = c.packetReadWaiter.WaitReadPacket() } else { - dest, err = c.NetPacketConn.ReadPacket(newBuffer()) + dest, err = c.SingPacketConn.ReadPacket(newBuffer()) } if dest.IsFqdn() { addr = dest @@ -59,7 +59,7 @@ func (c *enhanceSingPacketConn) WaitReadFrom() (data []byte, put func(), addr ne } func (c *enhanceSingPacketConn) Upstream() any { - return c.NetPacketConn + return c.SingPacketConn } func (c *enhanceSingPacketConn) WriterReplaceable() bool { @@ -70,10 +70,10 @@ func (c *enhanceSingPacketConn) ReaderReplaceable() bool { return true } -func newEnhanceSingPacketConn(conn N.NetPacketConn) *enhanceSingPacketConn { - epc := &enhanceSingPacketConn{NetPacketConn: conn} +func newEnhanceSingPacketConn(conn SingPacketConn) *enhanceSingPacketConn { + epc := &enhanceSingPacketConn{SingPacketConn: conn} if readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn); isReadWaiter { - epc.readWaiter = readWaiter + epc.packetReadWaiter = readWaiter } return epc } diff --git a/common/net/packet/ref.go b/common/net/packet/ref.go new file mode 100644 index 00000000..a562b2e2 --- /dev/null +++ b/common/net/packet/ref.go @@ -0,0 +1,75 @@ +package packet + +import ( + "net" + "runtime" + "time" +) + +type refPacketConn struct { + pc EnhancePacketConn + ref any +} + +func (c *refPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + defer runtime.KeepAlive(c.ref) + return c.pc.WaitReadFrom() +} + +func (c *refPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + defer runtime.KeepAlive(c.ref) + return c.pc.ReadFrom(p) +} + +func (c *refPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + defer runtime.KeepAlive(c.ref) + return c.pc.WriteTo(p, addr) +} + +func (c *refPacketConn) Close() error { + defer runtime.KeepAlive(c.ref) + return c.pc.Close() +} + +func (c *refPacketConn) LocalAddr() net.Addr { + defer runtime.KeepAlive(c.ref) + return c.pc.LocalAddr() +} + +func (c *refPacketConn) SetDeadline(t time.Time) error { + defer runtime.KeepAlive(c.ref) + return c.pc.SetDeadline(t) +} + +func (c *refPacketConn) SetReadDeadline(t time.Time) error { + defer runtime.KeepAlive(c.ref) + return c.pc.SetReadDeadline(t) +} + +func (c *refPacketConn) SetWriteDeadline(t time.Time) error { + defer runtime.KeepAlive(c.ref) + return c.pc.SetWriteDeadline(t) +} + +func (c *refPacketConn) Upstream() any { + return c.pc +} + +func (c *refPacketConn) ReaderReplaceable() bool { // Relay() will handle reference + return true +} + +func (c *refPacketConn) WriterReplaceable() bool { // Relay() will handle reference + return true +} + +func NewRefPacketConn(pc net.PacketConn, ref any) EnhancePacketConn { + rPC := &refPacketConn{pc: NewEnhancePacketConn(pc), ref: ref} + if singPC, isSingPC := pc.(SingPacketConn); isSingPC { + return &refSingPacketConn{ + refPacketConn: rPC, + singPacketConn: singPC, + } + } + return rPC +} diff --git a/common/net/packet/ref_sing.go b/common/net/packet/ref_sing.go new file mode 100644 index 00000000..2ca955fa --- /dev/null +++ b/common/net/packet/ref_sing.go @@ -0,0 +1,26 @@ +package packet + +import ( + "runtime" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type refSingPacketConn struct { + *refPacketConn + singPacketConn SingPacketConn +} + +var _ N.NetPacketConn = (*refSingPacketConn)(nil) + +func (c *refSingPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + defer runtime.KeepAlive(c.ref) + return c.singPacketConn.WritePacket(buffer, destination) +} + +func (c *refSingPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + defer runtime.KeepAlive(c.ref) + return c.singPacketConn.ReadPacket(buffer) +} diff --git a/common/net/packet/thread.go b/common/net/packet/thread.go new file mode 100644 index 00000000..14d64233 --- /dev/null +++ b/common/net/packet/thread.go @@ -0,0 +1,36 @@ +package packet + +import ( + "net" + "sync" +) + +type threadSafePacketConn struct { + EnhancePacketConn + access sync.Mutex +} + +func (c *threadSafePacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + c.access.Lock() + defer c.access.Unlock() + return c.EnhancePacketConn.WriteTo(b, addr) +} + +func (c *threadSafePacketConn) Upstream() any { + return c.EnhancePacketConn +} + +func (c *threadSafePacketConn) ReaderReplaceable() bool { + return true +} + +func NewThreadSafePacketConn(pc net.PacketConn) EnhancePacketConn { + tsPC := &threadSafePacketConn{EnhancePacketConn: NewEnhancePacketConn(pc)} + if singPC, isSingPC := pc.(SingPacketConn); isSingPC { + return &threadSafeSingPacketConn{ + threadSafePacketConn: tsPC, + singPacketConn: singPC, + } + } + return tsPC +} diff --git a/common/net/packet/thread_sing.go b/common/net/packet/thread_sing.go new file mode 100644 index 00000000..0869a512 --- /dev/null +++ b/common/net/packet/thread_sing.go @@ -0,0 +1,24 @@ +package packet + +import ( + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type threadSafeSingPacketConn struct { + *threadSafePacketConn + singPacketConn SingPacketConn +} + +var _ N.NetPacketConn = (*threadSafeSingPacketConn)(nil) + +func (c *threadSafeSingPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + c.access.Lock() + defer c.access.Unlock() + return c.singPacketConn.WritePacket(buffer, destination) +} + +func (c *threadSafeSingPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + return c.singPacketConn.ReadPacket(buffer) +} diff --git a/common/net/refconn.go b/common/net/refconn.go index 0f32ebc1..5caaebc8 100644 --- a/common/net/refconn.go +++ b/common/net/refconn.go @@ -80,64 +80,3 @@ var _ ExtendedConn = (*refConn)(nil) func NewRefConn(conn net.Conn, ref any) net.Conn { return &refConn{conn: NewExtendedConn(conn), ref: ref} } - -type refPacketConn struct { - pc EnhancePacketConn - ref any -} - -func (pc *refPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { - defer runtime.KeepAlive(pc.ref) - return pc.pc.WaitReadFrom() -} - -func (pc *refPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - defer runtime.KeepAlive(pc.ref) - return pc.pc.ReadFrom(p) -} - -func (pc *refPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - defer runtime.KeepAlive(pc.ref) - return pc.pc.WriteTo(p, addr) -} - -func (pc *refPacketConn) Close() error { - defer runtime.KeepAlive(pc.ref) - return pc.pc.Close() -} - -func (pc *refPacketConn) LocalAddr() net.Addr { - defer runtime.KeepAlive(pc.ref) - return pc.pc.LocalAddr() -} - -func (pc *refPacketConn) SetDeadline(t time.Time) error { - defer runtime.KeepAlive(pc.ref) - return pc.pc.SetDeadline(t) -} - -func (pc *refPacketConn) SetReadDeadline(t time.Time) error { - defer runtime.KeepAlive(pc.ref) - return pc.pc.SetReadDeadline(t) -} - -func (pc *refPacketConn) SetWriteDeadline(t time.Time) error { - defer runtime.KeepAlive(pc.ref) - return pc.pc.SetWriteDeadline(t) -} - -func (pc *refPacketConn) Upstream() any { - return pc.pc -} - -func (pc *refPacketConn) ReaderReplaceable() bool { // Relay() will handle reference - return true -} - -func (pc *refPacketConn) WriterReplaceable() bool { // Relay() will handle reference - return true -} - -func NewRefPacketConn(pc net.PacketConn, ref any) net.PacketConn { - return &refPacketConn{pc: NewEnhancePacketConn(pc), ref: ref} -}