chore: use early conn to support real ws 0-rtt

This commit is contained in:
wwqgtxx 2023-02-24 09:54:54 +08:00
parent a1d008e6f0
commit 75680c5866
13 changed files with 132 additions and 40 deletions

View file

@ -53,6 +53,9 @@ func (rw *nopConn) Read(b []byte) (int, error) {
}
func (rw *nopConn) Write(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
return 0, io.EOF
}

View file

@ -103,9 +103,9 @@ func (ss *ShadowSocks) streamConn(c net.Conn, metadata *C.Metadata) (net.Conn, e
}
}
if metadata.NetWork == C.UDP && ss.option.UDPOverTCP {
return ss.method.DialConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443"))
return ss.method.DialEarlyConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443")), nil
}
return ss.method.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
return ss.method.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
}
// DialContext implements C.ProxyAdapter

View file

@ -213,12 +213,12 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
}
if metadata.NetWork == C.UDP {
if v.option.XUDP {
return v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
return v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
} else {
return v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
return v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
}
} else {
return v.client.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
return v.client.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
}
}
@ -289,9 +289,9 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
}(c)
if v.option.XUDP {
c, err = v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
c = v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
} else {
c, err = v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
c = v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
}
if err != nil {

View file

@ -7,6 +7,7 @@ import (
"time"
"github.com/Dreamacro/clash/adapter/outbound"
"github.com/Dreamacro/clash/common/callback"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/provider"
@ -30,11 +31,21 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts .
c, err := proxy.DialContext(ctx, metadata, f.Base.DialOptions(opts...)...)
if err == nil {
c.AppendToChains(f)
f.onDialSuccess()
} else {
f.onDialFailed(proxy.Type(), err)
}
c = &callback.FirstWriteCallBackConn{
Conn: c,
Callback: func(err error) {
if err == nil {
f.onDialSuccess()
} else {
f.onDialFailed(proxy.Type(), err)
}
},
}
return c, err
}

View file

@ -10,6 +10,7 @@ import (
"github.com/Dreamacro/clash/adapter/outbound"
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/callback"
"github.com/Dreamacro/clash/common/murmur3"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
@ -83,17 +84,24 @@ func jumpHash(key uint64, buckets int32) int32 {
// DialContext implements C.ProxyAdapter
func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
proxy := lb.Unwrap(metadata, true)
defer func() {
if err == nil {
c.AppendToChains(lb)
lb.onDialSuccess()
} else {
lb.onDialFailed(proxy.Type(), err)
}
}()
c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
if err == nil {
c.AppendToChains(lb)
} else {
lb.onDialFailed(proxy.Type(), err)
}
c = &callback.FirstWriteCallBackConn{
Conn: c,
Callback: func(err error) {
if err == nil {
lb.onDialSuccess()
} else {
lb.onDialFailed(proxy.Type(), err)
}
},
}
return
}

View file

@ -6,6 +6,7 @@ import (
"time"
"github.com/Dreamacro/clash/adapter/outbound"
"github.com/Dreamacro/clash/common/callback"
"github.com/Dreamacro/clash/common/singledo"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
@ -38,10 +39,20 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ..
c, err = proxy.DialContext(ctx, metadata, u.Base.DialOptions(opts...)...)
if err == nil {
c.AppendToChains(u)
u.onDialSuccess()
} else {
u.onDialFailed(proxy.Type(), err)
}
c = &callback.FirstWriteCallBackConn{
Conn: c,
Callback: func(err error) {
if err == nil {
u.onDialSuccess()
} else {
u.onDialFailed(proxy.Type(), err)
}
},
}
return c, err
}

View file

@ -0,0 +1,25 @@
package callback
import (
C "github.com/Dreamacro/clash/constant"
)
type FirstWriteCallBackConn struct {
C.Conn
Callback func(error)
written bool
}
func (c *FirstWriteCallBackConn) Write(b []byte) (n int, err error) {
defer func() {
if !c.written {
c.written = true
c.Callback(err)
}
}()
return c.Conn.Write(b)
}
func (c *FirstWriteCallBackConn) Upstream() any {
return c.Conn
}

View file

@ -12,13 +12,14 @@ var _ ExtendedConn = (*BufferedConn)(nil)
type BufferedConn struct {
r *bufio.Reader
ExtendedConn
peeked bool
}
func NewBufferedConn(c net.Conn) *BufferedConn {
if bc, ok := c.(*BufferedConn); ok {
return bc
}
return &BufferedConn{bufio.NewReader(c), NewExtendedConn(c)}
return &BufferedConn{bufio.NewReader(c), NewExtendedConn(c), false}
}
// Reader returns the internal bufio.Reader.
@ -26,11 +27,20 @@ func (c *BufferedConn) Reader() *bufio.Reader {
return c.r
}
func (c *BufferedConn) Peeked() bool {
return c.peeked
}
// Peek returns the next n bytes without advancing the reader.
func (c *BufferedConn) Peek(n int) ([]byte, error) {
c.peeked = true
return c.r.Peek(n)
}
func (c *BufferedConn) Discard(n int) (discarded int, err error) {
return c.r.Discard(n)
}
func (c *BufferedConn) Read(p []byte) (int, error) {
return c.r.Read(p)
}

View file

@ -36,12 +36,7 @@ type SnifferDispatcher struct {
parsePureIp bool
}
func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
bufConn, ok := conn.(*N.BufferedConn)
if !ok {
return
}
func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) {
if (metadata.Host == "" && sd.parsePureIp) || sd.forceDomain.Search(metadata.Host) != nil || (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping) {
port, err := strconv.ParseUint(metadata.DstPort, 10, 16)
if err != nil {
@ -74,7 +69,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
}
sd.rwMux.RUnlock()
if host, err := sd.sniffDomain(bufConn, metadata); err != nil {
if host, err := sd.sniffDomain(conn, metadata); err != nil {
sd.cacheSniffFailed(metadata)
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
return

View file

@ -3,6 +3,8 @@ package constant
import (
"net"
N "github.com/Dreamacro/clash/common/net"
"github.com/gofrs/uuid"
)
@ -13,7 +15,7 @@ type PlainContext interface {
type ConnContext interface {
PlainContext
Metadata() *Metadata
Conn() net.Conn
Conn() *N.BufferedConn
}
type PacketConnContext interface {

View file

@ -12,7 +12,7 @@ import (
type ConnContext struct {
id uuid.UUID
metadata *C.Metadata
conn net.Conn
conn *N.BufferedConn
}
func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
@ -36,6 +36,6 @@ func (c *ConnContext) Metadata() *C.Metadata {
}
// Conn implement C.ConnContext Conn
func (c *ConnContext) Conn() net.Conn {
func (c *ConnContext) Conn() *N.BufferedConn {
return c.conn
}

View file

@ -7,7 +7,6 @@ import (
"io"
"net"
"sync"
"time"
"github.com/Dreamacro/clash/common/buf"
N "github.com/Dreamacro/clash/common/net"
@ -208,12 +207,12 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
}
}
go func() {
select {
case <-c.handshake:
case <-time.After(200 * time.Millisecond):
c.sendRequest(nil)
}
}()
//go func() {
// select {
// case <-c.handshake:
// case <-time.After(200 * time.Millisecond):
// c.sendRequest(nil)
// }
//}()
return c, nil
}

View file

@ -366,8 +366,20 @@ func handleTCPConn(connCtx C.ConnContext) {
return
}
conn := connCtx.Conn()
if sniffer.Dispatcher.Enable() && sniffingEnable {
sniffer.Dispatcher.TCPSniff(connCtx.Conn(), metadata)
sniffer.Dispatcher.TCPSniff(conn, metadata)
}
peekMutex := sync.Mutex{}
if !conn.Peeked() {
peekMutex.Lock()
go func() {
defer peekMutex.Unlock()
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
_, _ = conn.Peek(1)
_ = conn.SetReadDeadline(time.Time{})
}()
}
proxy, rule, err := resolveMetadata(connCtx, metadata)
@ -387,10 +399,26 @@ func handleTCPConn(connCtx C.ConnContext) {
}
}
var peekBytes []byte
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
remoteConn, err := retry(ctx, func(ctx context.Context) (C.Conn, error) {
return proxy.DialContext(ctx, dialMetadata)
remoteConn, err := proxy.DialContext(ctx, dialMetadata)
if err != nil {
return nil, err
}
peekMutex.Lock()
defer peekMutex.Unlock()
peekBytes, _ = conn.Peek(conn.Buffered())
_, err = remoteConn.Write(peekBytes)
if err != nil {
return nil, err
}
if peekLen := len(peekBytes); peekLen > 0 {
_, _ = conn.Discard(peekLen)
}
return remoteConn, err
}, func(err error) {
if rule == nil {
log.Warnln(