fix: trying to let hysteria's port hopping work

This commit is contained in:
wwqgtxx 2022-12-23 11:00:55 +08:00
parent daf0b23805
commit a03af85a6b
4 changed files with 36 additions and 35 deletions

View file

@ -89,7 +89,7 @@ type HysteriaOption struct {
BasicOption BasicOption
Name string `proxy:"name"` Name string `proxy:"name"`
Server string `proxy:"server"` Server string `proxy:"server"`
Port int `proxy:"port"` Port int `proxy:"port,omitempty"`
Ports string `proxy:"ports,omitempty"` Ports string `proxy:"ports,omitempty"`
Protocol string `proxy:"protocol,omitempty"` Protocol string `proxy:"protocol,omitempty"`
ObfsProtocol string `proxy:"obfs-protocol,omitempty"` // compatible with Stash ObfsProtocol string `proxy:"obfs-protocol,omitempty"` // compatible with Stash
@ -134,12 +134,8 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
Timeout: 8 * time.Second, Timeout: 8 * time.Second,
}, },
} }
var addr string addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
if len(option.Ports) == 0 { ports := option.Ports
addr = net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
} else {
addr = net.JoinHostPort(option.Server, option.Ports)
}
serverName := option.Server serverName := option.Server
if option.SNI != "" { if option.SNI != "" {
@ -244,7 +240,7 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
down = uint64(option.DownSpeed * mbpsToBps) down = uint64(option.DownSpeed * mbpsToBps)
} }
client, err := core.NewClient( client, err := core.NewClient(
addr, option.Protocol, auth, tlsConfig, quicConfig, clientTransport, up, down, func(refBPS uint64) congestion.CongestionControl { addr, ports, option.Protocol, auth, tlsConfig, quicConfig, clientTransport, up, down, func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
}, obfuscator, hopInterval, option.FastOpen, }, obfuscator, hopInterval, option.FastOpen,
) )

View file

@ -58,20 +58,24 @@ type udpPacket struct {
addr net.Addr addr net.Addr
} }
func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obfs obfs.Obfuscator, dialer utils.PacketDialer) (*ObfsUDPHopClientPacketConn, error) { func NewObfsUDPHopClientPacketConn(server string, serverPorts string, hopInterval time.Duration, obfs obfs.Obfuscator, dialer utils.PacketDialer) (*ObfsUDPHopClientPacketConn, error) {
host, ports, err := parseAddr(server) ports, err := parsePorts(serverPorts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Resolve the server IP address, then attach the ports to UDP addresses // Resolve the server IP address, then attach the ports to UDP addresses
ip, err := dialer.RemoteAddr(host) rAddr, err := dialer.RemoteAddr(server)
if err != nil {
return nil, err
}
ip, _, err := net.SplitHostPort(rAddr.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
serverAddrs := make([]net.Addr, len(ports)) serverAddrs := make([]net.Addr, len(ports))
for i, port := range ports { for i, port := range ports {
serverAddrs[i] = &net.UDPAddr{ serverAddrs[i] = &net.UDPAddr{
IP: net.ParseIP(ip.String()), IP: net.ParseIP(ip),
Port: int(port), Port: int(port),
} }
} }
@ -90,7 +94,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf
}, },
}, },
} }
curConn, err := dialer.ListenPacket(ip) curConn, err := dialer.ListenPacket(rAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -100,7 +104,7 @@ func NewObfsUDPHopClientPacketConn(server string, hopInterval time.Duration, obf
conn.currentConn = curConn conn.currentConn = curConn
} }
go conn.recvRoutine(conn.currentConn) go conn.recvRoutine(conn.currentConn)
go conn.hopRoutine(dialer, ip) go conn.hopRoutine(dialer, rAddr)
return conn, nil return conn, nil
} }
@ -307,29 +311,25 @@ func trySetPacketConnWriteBuffer(pc net.PacketConn, bytes int) error {
return nil return nil
} }
// parseAddr parses the multi-port server address and returns the host and ports. // parsePorts parses the multi-port server address and returns the host and ports.
// Supports both comma-separated single ports and dash-separated port ranges. // Supports both comma-separated single ports and dash-separated port ranges.
// Format: "host:port1,port2-port3,port4" // Format: "host:port1,port2-port3,port4"
func parseAddr(addr string) (host string, ports []uint16, err error) { func parsePorts(serverPorts string) (ports []uint16, err error) {
host, portStr, err := net.SplitHostPort(addr) portStrs := strings.Split(serverPorts, ",")
if err != nil {
return "", nil, err
}
portStrs := strings.Split(portStr, ",")
for _, portStr := range portStrs { for _, portStr := range portStrs {
if strings.Contains(portStr, "-") { if strings.Contains(portStr, "-") {
// Port range // Port range
portRange := strings.Split(portStr, "-") portRange := strings.Split(portStr, "-")
if len(portRange) != 2 { if len(portRange) != 2 {
return "", nil, net.InvalidAddrError("invalid port range") return nil, net.InvalidAddrError("invalid port range")
} }
start, err := strconv.ParseUint(portRange[0], 10, 16) start, err := strconv.ParseUint(portRange[0], 10, 16)
if err != nil { if err != nil {
return "", nil, net.InvalidAddrError("invalid port range") return nil, net.InvalidAddrError("invalid port range")
} }
end, err := strconv.ParseUint(portRange[1], 10, 16) end, err := strconv.ParseUint(portRange[1], 10, 16)
if err != nil { if err != nil {
return "", nil, net.InvalidAddrError("invalid port range") return nil, net.InvalidAddrError("invalid port range")
} }
if start > end { if start > end {
start, end = end, start start, end = end, start
@ -341,10 +341,13 @@ func parseAddr(addr string) (host string, ports []uint16, err error) {
// Single port // Single port
port, err := strconv.ParseUint(portStr, 10, 16) port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil { if err != nil {
return "", nil, net.InvalidAddrError("invalid port") return nil, net.InvalidAddrError("invalid port")
} }
ports = append(ports, uint16(port)) ports = append(ports, uint16(port))
} }
} }
return host, ports, nil if len(ports) == 0 {
return nil, net.InvalidAddrError("invalid port")
}
return ports, nil
} }

View file

@ -31,6 +31,7 @@ type CongestionFactory func(refBPS uint64) congestion.CongestionControl
type Client struct { type Client struct {
transport *transport.ClientTransport transport *transport.ClientTransport
serverAddr string serverAddr string
serverPorts string
protocol string protocol string
sendBPS, recvBPS uint64 sendBPS, recvBPS uint64
auth []byte auth []byte
@ -51,13 +52,14 @@ type Client struct {
fastOpen bool fastOpen bool
} }
func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, func NewClient(serverAddr string, serverPorts string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
obfuscator obfs.Obfuscator, hopInterval time.Duration, fastOpen bool) (*Client, error) { obfuscator obfs.Obfuscator, hopInterval time.Duration, fastOpen bool) (*Client, error) {
quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery
c := &Client{ c := &Client{
transport: transport, transport: transport,
serverAddr: serverAddr, serverAddr: serverAddr,
serverPorts: serverPorts,
protocol: protocol, protocol: protocol,
sendBPS: sendBPS, sendBPS: sendBPS,
recvBPS: recvBPS, recvBPS: recvBPS,
@ -73,7 +75,7 @@ func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.C
} }
func (c *Client) connectToServer(dialer utils.PacketDialer) error { func (c *Client) connectToServer(dialer utils.PacketDialer) error {
qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator, c.hopInterval, dialer) qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.serverPorts, c.tlsConfig, c.quicConfig, c.obfuscator, c.hopInterval, dialer)
if err != nil { if err != nil {
return err return err
} }

View file

@ -20,7 +20,7 @@ type ClientTransport struct {
Dialer *net.Dialer Dialer *net.Dialer
} }
func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (net.PacketConn, error) { func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, serverPorts string, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (net.PacketConn, error) {
server := rAddr.String() server := rAddr.String()
if len(proto) == 0 || proto == "udp" { if len(proto) == 0 || proto == "udp" {
conn, err := dialer.ListenPacket(rAddr) conn, err := dialer.ListenPacket(rAddr)
@ -28,14 +28,14 @@ func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, obfs obf
return nil, err return nil, err
} }
if obfs != nil { if obfs != nil {
if isMultiPortAddr(server) { if serverPorts != "" {
return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, obfs, dialer) return udp.NewObfsUDPHopClientPacketConn(server, serverPorts, hopInterval, obfs, dialer)
} }
oc := udp.NewObfsUDPConn(conn, obfs) oc := udp.NewObfsUDPConn(conn, obfs)
return oc, nil return oc, nil
} else { } else {
if isMultiPortAddr(server) { if serverPorts != "" {
return udp.NewObfsUDPHopClientPacketConn(server, hopInterval, nil, dialer) return udp.NewObfsUDPHopClientPacketConn(server, serverPorts, hopInterval, nil, dialer)
} }
return conn, nil return conn, nil
} }
@ -65,13 +65,13 @@ func (ct *ClientTransport) quicPacketConn(proto string, rAddr net.Addr, obfs obf
} }
} }
func (ct *ClientTransport) QUICDial(proto string, server string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (quic.Connection, error) { func (ct *ClientTransport) QUICDial(proto string, server string, serverPorts string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfsPkg.Obfuscator, hopInterval time.Duration, dialer utils.PacketDialer) (quic.Connection, error) {
serverUDPAddr, err := dialer.RemoteAddr(server) serverUDPAddr, err := dialer.RemoteAddr(server)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pktConn, err := ct.quicPacketConn(proto, serverUDPAddr, obfs, hopInterval, dialer) pktConn, err := ct.quicPacketConn(proto, serverUDPAddr, serverPorts, obfs, hopInterval, dialer)
if err != nil { if err != nil {
return nil, err return nil, err
} }