diff --git a/adapter/outbound/snell.go b/adapter/outbound/snell.go index b54d328f..92343055 100644 --- a/adapter/outbound/snell.go +++ b/adapter/outbound/snell.go @@ -53,6 +53,10 @@ func streamConn(c net.Conn, option streamOption) *snell.Snell { // StreamConn implements C.ProxyAdapter func (s *Snell) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption}) + if metadata.NetWork == C.UDP { + err := snell.WriteUDPHeader(c, s.version) + return c, err + } port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) err := snell.WriteHeader(c, metadata.String(), uint(port), s.version) return c, err @@ -100,12 +104,6 @@ func (s *Snell) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o // ListenPacketOnStreamConn implements C.ProxyAdapter func (s *Snell) ListenPacketOnStreamConn(c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { - - err = snell.WriteUDPHeader(c, s.version) - if err != nil { - return nil, err - } - pc := snell.PacketConn(c) return newPacketConn(pc, s), nil } diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index 65f3cbea..46586673 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -90,6 +90,10 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) return nil, err } + if metadata.NetWork == C.UDP { + err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata)) + return c, err + } err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)) return c, err } @@ -162,11 +166,6 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata, // ListenPacketOnStreamConn implements C.ProxyAdapter func (t *Trojan) ListenPacketOnStreamConn(c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { - err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata)) - if err != nil { - return nil, err - } - pc := t.instance.PacketConn(c) return newPacketConn(pc, t), err }