diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 01d161ec..3dac52da 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -3,11 +3,13 @@ package outbound import ( "context" "crypto/tls" + "encoding/binary" "errors" "fmt" "net" "net/http" "strconv" + "sync" "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/resolver" @@ -18,6 +20,11 @@ import ( "golang.org/x/net/http2" ) +const ( + // max packet length + maxLength = 8192 +) + type Vless struct { *Base client *vless.Client @@ -280,17 +287,85 @@ func parseVlessAddr(metadata *C.Metadata) *vless.DstAddr { type vlessPacketConn struct { net.Conn rAddr net.Addr + remain int + mux sync.Mutex + cache []byte } -func (uc *vlessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - return uc.Conn.Write(b) +func (c *vlessPacketConn) writePacket(b []byte, addr net.Addr) (int, error) { + length := len(b) + defer func() { + c.cache = c.cache[:0] + }() + c.cache = append(c.cache, byte(length>>8), byte(length)) + c.cache = append(c.cache, b...) + n, err := c.Conn.Write(c.cache) + if n > 2 { + return n - 2, err + } + + return 0, err } -func (uc *vlessPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, err := uc.Conn.Read(b) - return n, uc.rAddr, err +func (c *vlessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + if len(b) <= maxLength { + return c.writePacket(b, addr) + } + + offset := 0 + total := len(b) + for offset < total { + cursor := offset + maxLength + if cursor > total { + cursor = total + } + + n, err := c.writePacket(b[offset:cursor], addr) + if err != nil { + return offset + n, err + } + + offset = cursor + } + + return total, nil } + +func (c *vlessPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + c.mux.Lock() + defer c.mux.Unlock() + + length := len(b) + if c.remain > 0 { + if c.remain < length { + length = c.remain + } + + n, err := c.Conn.Read(b[:length]) + if err != nil { + return 0, nil, err + } + + c.remain -= n + return n, c.rAddr, nil + } + + var packetLength uint16 + if err := binary.Read(c.Conn, binary.BigEndian, &packetLength); err != nil { + return 0, nil, err + } + + remain := int(packetLength) + n, err := c.Conn.Read(b[:length]) + remain -= n + if remain > 0 { + c.remain = remain + } + return n, c.rAddr, err +} + + func NewVless(option VlessOption) (*Vless, error) { if !option.TLS && option.Network == "grpc" { return nil, fmt.Errorf("TLS must be true with vless-grpc")