diff --git a/adapters/outbound/snell.go b/adapters/outbound/snell.go index 61c103b8..edc0ddd7 100644 --- a/adapters/outbound/snell.go +++ b/adapters/outbound/snell.go @@ -16,7 +16,9 @@ import ( type Snell struct { *Base psk []byte + pool *snell.Pool obfsOption *simpleObfsOption + version int } type SnellOption struct { @@ -24,24 +26,47 @@ type SnellOption struct { Server string `proxy:"server"` Port int `proxy:"port"` Psk string `proxy:"psk"` + Version int `proxy:"version,omitempty"` ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"` } -func (s *Snell) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { - switch s.obfsOption.Mode { +type streamOption struct { + psk []byte + version int + addr string + obfsOption *simpleObfsOption +} + +func streamConn(c net.Conn, option streamOption) *snell.Snell { + switch option.obfsOption.Mode { case "tls": - c = obfs.NewTLSObfs(c, s.obfsOption.Host) + c = obfs.NewTLSObfs(c, option.obfsOption.Host) case "http": - _, port, _ := net.SplitHostPort(s.addr) - c = obfs.NewHTTPObfs(c, s.obfsOption.Host, port) + _, port, _ := net.SplitHostPort(option.addr) + c = obfs.NewHTTPObfs(c, option.obfsOption.Host, port) } - c = snell.StreamConn(c, s.psk) + return snell.StreamConn(c, option.psk, option.version) +} + +func (s *Snell) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { + c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption}) port, _ := strconv.Atoi(metadata.DstPort) - err := snell.WriteHeader(c, metadata.String(), uint(port)) + err := snell.WriteHeader(c, metadata.String(), uint(port), s.version) return c, err } func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + if s.version == snell.Version2 { + c, err := s.pool.Get() + if err != nil { + return nil, err + } + + port, _ := strconv.Atoi(metadata.DstPort) + err = snell.WriteHeader(c, metadata.String(), uint(port), s.version) + return NewConn(c, s), err + } + c, err := dialer.DialContext(ctx, "tcp", s.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %w", s.addr, err) @@ -66,7 +91,15 @@ func NewSnell(option SnellOption) (*Snell, error) { return nil, fmt.Errorf("snell %s obfs mode error: %s", addr, obfsOption.Mode) } - return &Snell{ + // backward compatible + if option.Version == 0 { + option.Version = snell.DefaultSnellVersion + } + if option.Version != snell.Version1 && option.Version != snell.Version2 { + return nil, fmt.Errorf("snell version error: %d", option.Version) + } + + s := &Snell{ Base: &Base{ name: option.Name, addr: addr, @@ -74,5 +107,19 @@ func NewSnell(option SnellOption) (*Snell, error) { }, psk: psk, obfsOption: obfsOption, - }, nil + version: option.Version, + } + + if option.Version == snell.Version2 { + s.pool = snell.NewPool(func(ctx context.Context) (*snell.Snell, error) { + c, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + tcpKeepAlive(c) + return streamConn(c, streamOption{psk, option.Version, addr, obfsOption}), nil + }) + } + return s, nil } diff --git a/component/pool/pool.go b/component/pool/pool.go new file mode 100644 index 00000000..ce40a129 --- /dev/null +++ b/component/pool/pool.go @@ -0,0 +1,114 @@ +package pool + +import ( + "context" + "runtime" + "time" +) + +type Factory = func(context.Context) (interface{}, error) + +type entry struct { + elm interface{} + time time.Time +} + +type Option func(*pool) + +// WithEvict set the evict callback +func WithEvict(cb func(interface{})) Option { + return func(p *pool) { + p.evict = cb + } +} + +// WithAge defined element max age (millisecond) +func WithAge(maxAge int64) Option { + return func(p *pool) { + p.maxAge = maxAge + } +} + +// WithSize defined max size of Pool +func WithSize(maxSize int) Option { + return func(p *pool) { + p.ch = make(chan interface{}, maxSize) + } +} + +// Pool is for GC, see New for detail +type Pool struct { + *pool +} + +type pool struct { + ch chan interface{} + factory Factory + evict func(interface{}) + maxAge int64 +} + +func (p *pool) GetContext(ctx context.Context) (interface{}, error) { + now := time.Now() + for { + select { + case item := <-p.ch: + elm := item.(*entry) + if p.maxAge != 0 && now.Sub(item.(*entry).time).Milliseconds() > p.maxAge { + if p.evict != nil { + p.evict(elm.elm) + } + continue + } + + return elm.elm, nil + default: + return p.factory(ctx) + } + } +} + +func (p *pool) Get() (interface{}, error) { + return p.GetContext(context.Background()) +} + +func (p *pool) Put(item interface{}) { + e := &entry{ + elm: item, + time: time.Now(), + } + + select { + case p.ch <- e: + return + default: + // pool is full + if p.evict != nil { + p.evict(item) + } + return + } +} + +func recycle(p *Pool) { + for item := range p.pool.ch { + if p.pool.evict != nil { + p.pool.evict(item.(*entry).elm) + } + } +} + +func New(factory Factory, options ...Option) *Pool { + p := &pool{ + ch: make(chan interface{}, 10), + factory: factory, + } + + for _, option := range options { + option(p) + } + + P := &Pool{p} + runtime.SetFinalizer(P, recycle) + return P +} diff --git a/component/pool/pool_test.go b/component/pool/pool_test.go new file mode 100644 index 00000000..844ef245 --- /dev/null +++ b/component/pool/pool_test.go @@ -0,0 +1,96 @@ +package pool + +import ( + "context" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func lg() Factory { + initial := -1 + return func(context.Context) (interface{}, error) { + initial++ + return initial, nil + } +} + +func TestPool_Basic(t *testing.T) { + g := lg() + pool := New(g) + + elm, _ := pool.Get() + assert.Equal(t, 0, elm.(int)) + pool.Put(elm) + elm, _ = pool.Get() + assert.Equal(t, 0, elm.(int)) + elm, _ = pool.Get() + assert.Equal(t, 1, elm.(int)) +} + +func TestPool_MaxSize(t *testing.T) { + g := lg() + size := 5 + pool := New(g, WithSize(size)) + + items := []interface{}{} + + for i := 0; i < size; i++ { + item, _ := pool.Get() + items = append(items, item) + } + + extra, _ := pool.Get() + assert.Equal(t, size, extra.(int)) + + for _, item := range items { + pool.Put(item) + } + + pool.Put(extra) + + for _, item := range items { + elm, _ := pool.Get() + assert.Equal(t, item.(int), elm.(int)) + } +} + +func TestPool_MaxAge(t *testing.T) { + g := lg() + pool := New(g, WithAge(20)) + + elm, _ := pool.Get() + pool.Put(elm) + + elm, _ = pool.Get() + assert.Equal(t, 0, elm.(int)) + pool.Put(elm) + + time.Sleep(time.Millisecond * 22) + elm, _ = pool.Get() + assert.Equal(t, 1, elm.(int)) +} + +func TestPool_AutoGC(t *testing.T) { + g := lg() + + sign := make(chan int) + pool := New(g, WithEvict(func(item interface{}) { + sign <- item.(int) + })) + + elm, _ := pool.Get() + assert.Equal(t, 0, elm.(int)) + pool.Put(2) + + runtime.GC() + + select { + case num := <-sign: + assert.Equal(t, 2, num) + case <-time.After(time.Second * 3): + assert.Fail(t, "something wrong") + } +} diff --git a/component/snell/cipher.go b/component/snell/cipher.go index f66f0801..f778e647 100644 --- a/component/snell/cipher.go +++ b/component/snell/cipher.go @@ -1,21 +1,54 @@ package snell import ( + "crypto/aes" "crypto/cipher" + "github.com/Dreamacro/go-shadowsocks2/shadowaead" "golang.org/x/crypto/argon2" + "golang.org/x/crypto/chacha20poly1305" ) type snellCipher struct { psk []byte + keySize int makeAEAD func(key []byte) (cipher.AEAD, error) } -func (sc *snellCipher) KeySize() int { return 32 } +func (sc *snellCipher) KeySize() int { return sc.keySize } func (sc *snellCipher) SaltSize() int { return 16 } func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) { - return sc.makeAEAD(argon2.IDKey(sc.psk, salt, 3, 8, 1, uint32(sc.KeySize()))) + return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize())) } func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) { - return sc.makeAEAD(argon2.IDKey(sc.psk, salt, 3, 8, 1, uint32(sc.KeySize()))) + return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize())) +} + +func snellKDF(psk, salt []byte, keySize int) []byte { + // snell use a special kdf function + return argon2.IDKey(psk, salt, 3, 8, 1, 32)[:keySize] +} + +func aesGCM(key []byte) (cipher.AEAD, error) { + blk, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewGCM(blk) +} + +func NewAES128GCM(psk []byte) shadowaead.Cipher { + return &snellCipher{ + psk: psk, + keySize: 16, + makeAEAD: aesGCM, + } +} + +func NewChacha20Poly1305(psk []byte) shadowaead.Cipher { + return &snellCipher{ + psk: psk, + keySize: 32, + makeAEAD: chacha20poly1305.New, + } } diff --git a/component/snell/pool.go b/component/snell/pool.go new file mode 100644 index 00000000..37b7de14 --- /dev/null +++ b/component/snell/pool.go @@ -0,0 +1,80 @@ +package snell + +import ( + "context" + "net" + + "github.com/Dreamacro/clash/component/pool" + "github.com/Dreamacro/go-shadowsocks2/shadowaead" +) + +type Pool struct { + pool *pool.Pool +} + +func (p *Pool) Get() (net.Conn, error) { + return p.GetContext(context.Background()) +} + +func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) { + elm, err := p.pool.GetContext(ctx) + if err != nil { + return nil, err + } + + return &PoolConn{elm.(*Snell), p}, nil +} + +func (p *Pool) Put(conn net.Conn) { + if err := HalfClose(conn); err != nil { + conn.Close() + return + } + + p.pool.Put(conn) +} + +type PoolConn struct { + *Snell + pool *Pool +} + +func (pc *PoolConn) Read(b []byte) (int, error) { + // save old status of reply (it mutable by Read) + reply := pc.Snell.reply + + n, err := pc.Snell.Read(b) + if err == shadowaead.ErrZeroChunk { + // if reply is false, it should be client halfclose. + // ignore error and read data again. + if !reply { + pc.Snell.reply = false + return pc.Snell.Read(b) + } + } + return n, err +} + +func (pc *PoolConn) Write(b []byte) (int, error) { + return pc.Snell.Write(b) +} + +func (pc *PoolConn) Close() error { + pc.pool.Put(pc.Snell) + return nil +} + +func NewPool(factory func(context.Context) (*Snell, error)) *Pool { + p := pool.New( + func(ctx context.Context) (interface{}, error) { + return factory(ctx) + }, + pool.WithAge(15000), + pool.WithSize(10), + pool.WithEvict(func(item interface{}) { + item.(*Snell).Close() + }), + ) + + return &Pool{p} +} diff --git a/component/snell/snell.go b/component/snell/snell.go index 920f1225..ecbc90ee 100644 --- a/component/snell/snell.go +++ b/component/snell/snell.go @@ -10,14 +10,21 @@ import ( "sync" "github.com/Dreamacro/go-shadowsocks2/shadowaead" - "golang.org/x/crypto/chacha20poly1305" ) const ( - CommandPing byte = 0 - CommandConnect byte = 1 + Version1 = 1 + Version2 = 2 + DefaultSnellVersion = Version1 +) + +const ( + CommandPing byte = 0 + CommandConnect byte = 1 + CommandConnectV2 byte = 5 CommandTunnel byte = 0 + CommandPong byte = 1 CommandError byte = 2 Version byte = 1 @@ -25,6 +32,7 @@ const ( var ( bufferPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }} + endSignal = []byte{} ) type Snell struct { @@ -70,12 +78,16 @@ func (s *Snell) Read(b []byte) (int, error) { return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg)) } -func WriteHeader(conn net.Conn, host string, port uint) error { +func WriteHeader(conn net.Conn, host string, port uint, version int) error { buf := bufferPool.Get().(*bytes.Buffer) buf.Reset() defer bufferPool.Put(buf) buf.WriteByte(Version) - buf.WriteByte(CommandConnect) + if version == Version2 { + buf.WriteByte(CommandConnectV2) + } else { + buf.WriteByte(CommandConnect) + } // clientID length & id buf.WriteByte(0) @@ -92,7 +104,24 @@ func WriteHeader(conn net.Conn, host string, port uint) error { return nil } -func StreamConn(conn net.Conn, psk []byte) net.Conn { - cipher := &snellCipher{psk, chacha20poly1305.New} +// HalfClose works only on version2 +func HalfClose(conn net.Conn) error { + if _, err := conn.Write(endSignal); err != nil { + return err + } + + if s, ok := conn.(*Snell); ok { + s.reply = false + } + return nil +} + +func StreamConn(conn net.Conn, psk []byte, version int) *Snell { + var cipher shadowaead.Cipher + if version == Version2 { + cipher = NewAES128GCM(psk) + } else { + cipher = NewChacha20Poly1305(psk) + } return &Snell{Conn: shadowaead.NewConn(conn, cipher)} } diff --git a/go.mod b/go.mod index ff19a68b..f1f85e52 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/Dreamacro/clash go 1.14 require ( - github.com/Dreamacro/go-shadowsocks2 v0.1.6-0.20200722122336-8e5c7db4f96a + github.com/Dreamacro/go-shadowsocks2 v0.1.6 github.com/eapache/queue v1.1.0 // indirect github.com/go-chi/chi v4.1.2+incompatible github.com/go-chi/cors v1.1.1 diff --git a/go.sum b/go.sum index bc588d74..8611148b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/Dreamacro/go-shadowsocks2 v0.1.6-0.20200722122336-8e5c7db4f96a h1:JhQFrFOkCpRB8qsN6PrzHFzjy/8iQpFFk5cbOiplh6s= -github.com/Dreamacro/go-shadowsocks2 v0.1.6-0.20200722122336-8e5c7db4f96a/go.mod h1:LSXCjyHesPY3pLjhwff1mQX72ItcBT/N2xNC685cYeU= +github.com/Dreamacro/go-shadowsocks2 v0.1.6 h1:PysSf9sLT3Qn8jhlin5v7Rk68gOQG4K5BZFY1nxLGxI= +github.com/Dreamacro/go-shadowsocks2 v0.1.6/go.mod h1:LSXCjyHesPY3pLjhwff1mQX72ItcBT/N2xNC685cYeU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -27,7 +27,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -53,7 +52,6 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed h1:J22ig1FUekjjkmZUM7pTKixYm8DvrYsvrBZdunYeIuQ= golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=