From e54f51af815fc9a11e154b0dbd9f3acfc3d04808 Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Fri, 20 Mar 2020 00:02:05 +0800 Subject: [PATCH] Fix: trojan split udp packet --- component/socks5/socks5.go | 6 +- component/trojan/trojan.go | 121 ++++++++++++++++++++++++++++++------- 2 files changed, 102 insertions(+), 25 deletions(-) diff --git a/component/socks5/socks5.go b/component/socks5/socks5.go index b150d1f7..76be1286 100644 --- a/component/socks5/socks5.go +++ b/component/socks5/socks5.go @@ -178,7 +178,7 @@ func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, } command = buf[1] - addr, err = readAddr(rw, buf) + addr, err = ReadAddr(rw, buf) if err != nil { return } @@ -260,10 +260,10 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) ( return nil, err } - return readAddr(rw, buf) + return ReadAddr(rw, buf) } -func readAddr(r io.Reader, b []byte) (Addr, error) { +func ReadAddr(r io.Reader, b []byte) (Addr, error) { if len(b) < MaxAddrLen { return nil, io.ErrShortBuffer } diff --git a/component/trojan/trojan.go b/component/trojan/trojan.go index efb802d5..2ae8d07f 100644 --- a/component/trojan/trojan.go +++ b/component/trojan/trojan.go @@ -7,12 +7,18 @@ import ( "encoding/binary" "encoding/hex" "errors" + "io" "net" "sync" "github.com/Dreamacro/clash/component/socks5" ) +const ( + // max packet length + maxLength = 8192 +) + var ( defaultALPN = []string{"h2", "http/1.1"} crlf = []byte{'\r', '\n'} @@ -62,7 +68,7 @@ func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) { return tlsConn, nil } -func (t *Trojan) WriteHeader(conn net.Conn, command Command, socks5Addr []byte) error { +func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error { buf := bufPool.Get().(*bytes.Buffer) defer buf.Reset() defer bufPool.Put(buf) @@ -74,15 +80,17 @@ func (t *Trojan) WriteHeader(conn net.Conn, command Command, socks5Addr []byte) buf.Write(socks5Addr) buf.Write(crlf) - _, err := conn.Write(buf.Bytes()) + _, err := w.Write(buf.Bytes()) return err } func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn { - return &PacketConn{conn} + return &PacketConn{ + Conn: conn, + } } -func WritePacket(conn net.Conn, socks5Addr, payload []byte) (int, error) { +func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { buf := bufPool.Get().(*bytes.Buffer) defer buf.Reset() defer bufPool.Put(buf) @@ -92,26 +100,67 @@ func WritePacket(conn net.Conn, socks5Addr, payload []byte) (int, error) { buf.Write(crlf) buf.Write(payload) - return conn.Write(buf.Bytes()) + return w.Write(buf.Bytes()) } -func DecodePacket(payload []byte) (net.Addr, []byte, error) { - addr := socks5.SplitAddr(payload) - if addr == nil { - return nil, nil, errors.New("split addr error") +func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { + if len(payload) <= maxLength { + return writePacket(w, socks5Addr, payload) } - buf := payload[len(addr):] - if len(buf) <= 4 { - return nil, nil, errors.New("packet invalid") + offset := 0 + total := len(payload) + for { + cursor := offset + maxLength + if cursor > total { + cursor = total + } + + n, err := writePacket(w, socks5Addr, payload[offset:cursor]) + if err != nil { + return offset + n, err + } + + offset = cursor + if offset == total { + break + } } - length := binary.BigEndian.Uint16(buf[:2]) - if len(buf) < 4+int(length) { - return nil, nil, errors.New("packet invalid") + return total, nil +} + +func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) { + addr, err := socks5.ReadAddr(r, payload) + if err != nil { + return nil, 0, 0, errors.New("read addr error") + } + uAddr := addr.UDPAddr() + + if _, err = io.ReadFull(r, payload[:2]); err != nil { + return nil, 0, 0, errors.New("read length error") } - return addr.UDPAddr(), buf[4 : 4+length], nil + total := int(binary.BigEndian.Uint16(payload[:2])) + if total > maxLength { + return nil, 0, 0, errors.New("packet invalid") + } + + // read crlf + if _, err = io.ReadFull(r, payload[:2]); err != nil { + return nil, 0, 0, errors.New("read crlf error") + } + + length := len(payload) + if total < length { + length = total + } + + if _, err = io.ReadFull(r, payload[:length]); err != nil { + return nil, 0, 0, errors.New("read packet error") + } + + return uAddr, length, total - length, nil } func New(option *Option) *Trojan { @@ -120,6 +169,9 @@ func New(option *Option) *Trojan { type PacketConn struct { net.Conn + remain int + rAddr net.Addr + mux sync.Mutex } func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { @@ -127,14 +179,39 @@ func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { } func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, err := pc.Conn.Read(b) - addr, payload, err := DecodePacket(b) - if err != nil { - return n, nil, err + pc.mux.Lock() + defer pc.mux.Unlock() + if pc.remain != 0 { + length := len(b) + if pc.remain < length { + length = pc.remain + } + + n, err := pc.Conn.Read(b[:length]) + if err != nil { + return 0, nil, err + } + + pc.remain -= n + addr := pc.rAddr + if pc.remain == 0 { + pc.rAddr = nil + } + + return n, addr, nil } - copy(b, payload) - return len(payload), addr, nil + addr, n, remain, err := ReadPacket(pc.Conn, b) + if err != nil { + return 0, nil, err + } + + if remain != 0 { + pc.remain = remain + pc.rAddr = addr + } + + return n, addr, nil } func hexSha224(data []byte) []byte {