From d47ce79a24d78256278990ae5fe3e6521ffb5277 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 25 Nov 2022 11:32:05 +0800 Subject: [PATCH] chore: better tuic conn close --- adapter/outbound/tuic.go | 6 ++++++ transport/tuic/client.go | 31 ++++++++++++++++++++----------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index 1cfba568..88dbf03d 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "os" + "runtime" "strconv" "sync" "time" @@ -199,6 +200,7 @@ func NewTuic(option TuicOption) (*Tuic, error) { RequestTimeout: option.RequestTimeout, } clientMap[o] = client + runtime.SetFinalizer(client, closeTuicClient) return client } @@ -214,3 +216,7 @@ func NewTuic(option TuicOption) (*Tuic, error) { getClient: getClient, }, nil } + +func closeTuicClient(client *tuic.Client) { + client.Close(nil) +} diff --git a/transport/tuic/client.go b/transport/tuic/client.go index 8f91ef63..b6bb0bf7 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/client.go @@ -197,18 +197,26 @@ func (t *Client) deferQuicConn(quicConn quic.Connection, err error) { t.connMutex.Lock() defer t.connMutex.Unlock() if t.quicConn == quicConn { - t.udpInputMap.Range(func(key, value any) bool { - if conn, ok := value.(net.Conn); ok { - _ = conn.Close() - } - return true - }) - t.udpInputMap = sync.Map{} // new one - t.quicConn = nil + t.Close(err) } } } +func (t *Client) Close(err error) { + quicConn := t.quicConn + if quicConn != nil { + _ = t.quicConn.CloseWithError(ProtocolError, err.Error()) + t.udpInputMap.Range(func(key, value any) bool { + if conn, ok := value.(net.Conn); ok { + _ = conn.Close() + } + return true + }) + t.udpInputMap = sync.Map{} // new one + t.quicConn = nil + } +} + func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) { quicConn, err := t.getQuicConn(ctx, dialFn) if err != nil { @@ -237,7 +245,7 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f if t.RequestTimeout > 0 { _ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond)) } - conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr()}) + conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr(), t}) response, err := ReadResponse(conn) if err != nil { return nil, err @@ -252,8 +260,9 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f type quicStreamConn struct { quic.Stream - lAddr net.Addr - rAddr net.Addr + lAddr net.Addr + rAddr net.Addr + client *Client } func (q *quicStreamConn) LocalAddr() net.Addr {