Feature: add vmess WebSocket early data (#1505)

Co-authored-by: ShinyGwyn <79344143+ShinyGwyn@users.noreply.github.com>
This commit is contained in:
秋のかえで 2021-08-22 00:25:29 +08:00 committed by GitHub
parent c6d375eda2
commit 0267b2efad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 181 additions and 18 deletions

View file

@ -43,6 +43,7 @@ type VmessOption struct {
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
WSOpts WSOptions `proxy:"ws-opts,omitempty"`
WSPath string `proxy:"ws-path,omitempty"`
WSHeaders map[string]string `proxy:"ws-headers,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
@ -64,19 +65,35 @@ type GrpcOptions struct {
GrpcServiceName string `proxy:"grpc-service-name,omitempty"`
}
type WSOptions struct {
Path string `proxy:"path,omitempty"`
Headers map[string]string `proxy:"headers,omitempty"`
MaxEarlyData int `proxy:"max-early-data,omitempty"`
EarlyDataHeaderName string `proxy:"early-data-header-name,omitempty"`
}
// StreamConn implements C.ProxyAdapter
func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error
switch v.option.Network {
case "ws":
if v.option.WSOpts.Path == "" {
v.option.WSOpts.Path = v.option.WSPath
}
if len(v.option.WSOpts.Headers) == 0 {
v.option.WSOpts.Headers = v.option.WSHeaders
}
host, port, _ := net.SplitHostPort(v.addr)
wsOpts := &vmess.WebsocketConfig{
Host: host,
Port: port,
Path: v.option.WSPath,
Path: v.option.WSOpts.Path,
MaxEarlyData: v.option.WSOpts.MaxEarlyData,
EarlyDataHeaderName: v.option.WSOpts.EarlyDataHeaderName,
}
if len(v.option.WSHeaders) != 0 {
if len(v.option.WSOpts.Headers) != 0 {
header := http.Header{}
for key, value := range v.option.WSHeaders {
header.Add(key, value)

View file

@ -59,12 +59,12 @@ func (vc *Conn) Read(b []byte) (int, error) {
func (vc *Conn) sendRequest() error {
timestamp := time.Now()
mbuf := &bytes.Buffer{}
if !vc.isAead {
h := hmac.New(md5.New, vc.id.UUID.Bytes())
binary.Write(h, binary.BigEndian, uint64(timestamp.Unix()))
if _, err := vc.Conn.Write(h.Sum(nil)); err != nil {
return err
}
mbuf.Write(h.Sum(nil))
}
buf := &bytes.Buffer{}
@ -110,7 +110,8 @@ func (vc *Conn) sendRequest() error {
stream := cipher.NewCFBEncrypter(block, hashTimestamp(timestamp))
stream.XORKeyStream(buf.Bytes(), buf.Bytes())
_, err = vc.Conn.Write(buf.Bytes())
mbuf.Write(buf.Bytes())
_, err = vc.Conn.Write(mbuf.Bytes())
return err
}

View file

@ -1,7 +1,11 @@
package vmess
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
@ -23,6 +27,15 @@ type websocketConn struct {
rMux sync.Mutex
wMux sync.Mutex
}
type websocketWithEarlyDataConn struct {
net.Conn
underlay net.Conn
closed bool
dialed chan bool
cancel context.CancelFunc
ctx context.Context
config *WebsocketConfig
}
type WebsocketConfig struct {
Host string
@ -32,6 +45,8 @@ type WebsocketConfig struct {
TLS bool
SkipCertVerify bool
ServerName string
MaxEarlyData int
EarlyDataHeaderName string
}
// Read implements net.Conn.Read()
@ -113,7 +128,121 @@ func (wsc *websocketConn) SetWriteDeadline(t time.Time) error {
return wsc.conn.SetWriteDeadline(t)
}
func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
earlyDataBuf := bytes.NewBuffer(nil)
base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf)
earlydata := bytes.NewReader(earlyData)
limitedEarlyDatareader := io.LimitReader(earlydata, int64(wsedc.config.MaxEarlyData))
n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader)
if encerr != nil {
return errors.New("failed to encode early data: " + encerr.Error())
}
if errc := base64EarlyDataEncoder.Close(); errc != nil {
return errors.New("failed to encode early data tail: " + errc.Error())
}
var err error
if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, earlyDataBuf); err != nil {
wsedc.Close()
return errors.New("failed to dial WebSocket: " + err.Error())
}
wsedc.dialed <- true
if n != int64(len(earlyData)) {
_, err = wsedc.Conn.Write(earlyData[n:])
}
return err
}
func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(b); err != nil {
return 0, err
}
return len(b), nil
}
return wsedc.Conn.Write(b)
}
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
}
if wsedc.Conn == nil {
select {
case <-wsedc.ctx.Done():
return 0, io.ErrUnexpectedEOF
case <-wsedc.dialed:
}
}
return wsedc.Conn.Read(b)
}
func (wsedc *websocketWithEarlyDataConn) Close() error {
wsedc.closed = true
wsedc.cancel()
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.Close()
}
func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
if wsedc.Conn == nil {
return wsedc.underlay.LocalAddr()
}
return wsedc.Conn.LocalAddr()
}
func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
if wsedc.Conn == nil {
return wsedc.underlay.RemoteAddr()
}
return wsedc.Conn.RemoteAddr()
}
func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
if err := wsedc.SetReadDeadline(t); err != nil {
return err
}
return wsedc.SetWriteDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.SetReadDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
if wsedc.Conn == nil {
return nil
}
return wsedc.Conn.SetWriteDeadline(t)
}
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
ctx, cancel := context.WithCancel(context.Background())
conn = &websocketWithEarlyDataConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
underlay: conn,
config: c,
}
return conn, nil
}
func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
dialer := &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
return conn, nil
@ -152,6 +281,14 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
}
}
if earlyData != nil {
if c.EarlyDataHeaderName == "" {
uri.Path += earlyData.String()
} else {
headers.Set(c.EarlyDataHeaderName, earlyData.String())
}
}
wsConn, resp, err := dialer.Dial(uri.String(), headers)
if err != nil {
reason := err.Error()
@ -166,3 +303,11 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
remoteAddr: conn.RemoteAddr(),
}, nil
}
func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
if c.MaxEarlyData > 0 {
return streamWebsocketWithEarlyDataConn(conn, c)
}
return streamWebsocketConn(conn, c, nil)
}