package snell import ( "encoding/binary" "errors" "fmt" "io" "net" "sync" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/transport/shadowsocks/shadowaead" "github.com/Dreamacro/clash/transport/socks5" ) const ( Version1 = 1 Version2 = 2 Version3 = 3 DefaultSnellVersion = Version1 // max packet length maxLength = 0x3FFF ) const ( CommandPing byte = 0 CommandConnect byte = 1 CommandConnectV2 byte = 5 CommandUDP byte = 6 CommondUDPForward byte = 1 CommandTunnel byte = 0 CommandPong byte = 1 CommandError byte = 2 Version byte = 1 ) var endSignal = []byte{} type Snell struct { net.Conn buffer [1]byte reply bool } func (s *Snell) Read(b []byte) (int, error) { if s.reply { return s.Conn.Read(b) } s.reply = true if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { return 0, err } if s.buffer[0] == CommandTunnel { return s.Conn.Read(b) } else if s.buffer[0] != CommandError { return 0, errors.New("command not support") } // CommandError // 1 byte error code if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { return 0, err } errcode := int(s.buffer[0]) // 1 byte error message length if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil { return 0, err } length := int(s.buffer[0]) msg := make([]byte, length) if _, err := io.ReadFull(s.Conn, msg); err != nil { return 0, err } return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg)) } func WriteHeader(conn net.Conn, host string, port uint, version int) error { buf := pool.GetBuffer() defer pool.PutBuffer(buf) buf.WriteByte(Version) if version == Version2 { buf.WriteByte(CommandConnectV2) } else { buf.WriteByte(CommandConnect) } // clientID length & id buf.WriteByte(0) // host & port buf.WriteByte(uint8(len(host))) buf.WriteString(host) binary.Write(buf, binary.BigEndian, uint16(port)) if _, err := conn.Write(buf.Bytes()); err != nil { return err } return nil } func WriteUDPHeader(conn net.Conn, version int) error { if version < Version3 { return errors.New("unsupport UDP version") } // version, command, clientID length _, err := conn.Write([]byte{Version, CommandUDP, 0x00}) return err } // HalfClose works only on version2 func HalfClose(conn net.Conn) error { if _, err := conn.Write(endSignal); err != nil { return err } if s, ok := conn.(*Snell); ok { s.reply = false } return nil } func StreamConn(conn net.Conn, psk []byte, version int) *Snell { var cipher shadowaead.Cipher if version != Version1 { cipher = NewAES128GCM(psk) } else { cipher = NewChacha20Poly1305(psk) } return &Snell{Conn: shadowaead.NewConn(conn, cipher)} } func PacketConn(conn net.Conn) net.PacketConn { return &packetConn{ Conn: conn, } } func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { buf := pool.GetBuffer() defer pool.PutBuffer(buf) // compose snell UDP address format (refer: icpz/snell-server-reversed) // a brand new wheel to replace socks5 address format, well done Yachen buf.WriteByte(CommondUDPForward) switch socks5Addr[0] { case socks5.AtypDomainName: hostLen := socks5Addr[1] buf.Write(socks5Addr[1 : 1+1+hostLen+2]) case socks5.AtypIPv4: buf.Write([]byte{0x00, 0x04}) buf.Write(socks5Addr[1 : 1+net.IPv4len+2]) case socks5.AtypIPv6: buf.Write([]byte{0x00, 0x06}) buf.Write(socks5Addr[1 : 1+net.IPv6len+2]) } buf.Write(payload) _, err := w.Write(buf.Bytes()) if err != nil { return 0, err } return len(payload), nil } func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) { if len(payload) <= maxLength { return writePacket(w, socks5Addr, payload) } offset := 0 total := len(payload) for { cursor := offset + maxLength if cursor > total { cursor = total } n, err := writePacket(w, socks5Addr, payload[offset:cursor]) if err != nil { return offset + n, err } offset = cursor if offset == total { break } } return total, nil } func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) { buf := pool.Get(pool.UDPBufferSize) defer pool.Put(buf) n, err := r.Read(buf) headLen := 1 if err != nil { return nil, 0, err } if n < headLen { return nil, 0, errors.New("insufficient UDP length") } // parse snell UDP response address format switch buf[0] { case 0x04: headLen += net.IPv4len + 2 if n < headLen { err = errors.New("insufficient UDP length") break } buf[0] = socks5.AtypIPv4 case 0x06: headLen += net.IPv6len + 2 if n < headLen { err = errors.New("insufficient UDP length") break } buf[0] = socks5.AtypIPv6 default: err = errors.New("ip version invalid") } if err != nil { return nil, 0, err } addr := socks5.SplitAddr(buf[0:]) if addr == nil { return nil, 0, errors.New("remote address invalid") } uAddr := addr.UDPAddr() if uAddr == nil { return nil, 0, errors.New("parse addr error") } length := len(payload) if n-headLen < length { length = n - headLen } copy(payload[:], buf[headLen:headLen+length]) return uAddr, length, nil } type packetConn struct { net.Conn rMux sync.Mutex wMux sync.Mutex } func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { pc.wMux.Lock() defer pc.wMux.Unlock() return WritePacket(pc, socks5.ParseAddr(addr.String()), b) } func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { pc.rMux.Lock() defer pc.rMux.Unlock() addr, n, err := ReadPacket(pc.Conn, b) if err != nil { return 0, nil, err } return n, addr, nil }