chore: decrease memory copy in quic sniffer

This commit is contained in:
wwqgtxx 2023-10-19 23:51:37 +08:00
parent 8e637a2ec7
commit ea7e15b447
5 changed files with 41 additions and 35 deletions

View file

@ -10,6 +10,7 @@ const BufferSize = buf.BufferSize
type Buffer = buf.Buffer type Buffer = buf.Buffer
var New = buf.New var New = buf.New
var NewPacket = buf.NewPacket
var NewSize = buf.NewSize var NewSize = buf.NewSize
var With = buf.With var With = buf.With
var As = buf.As var As = buf.As

View file

@ -51,10 +51,7 @@ func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
overrideDest := config.OverrideDest overrideDest := config.OverrideDest
if inWhitelist { if inWhitelist {
var copyBuf = make([]byte, len(packet.Data())) host, err := sniffer.SniffData(packet.Data())
copy(copyBuf, packet.Data())
host, err := sniffer.SniffData(copyBuf)
if err != nil { if err != nil {
continue continue
} }

View file

@ -107,10 +107,7 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) {
return "", errNotQuic return "", errNotQuic
} }
hdrLen := len(b) - int(buffer.Len()) hdrLen := len(b) - buffer.Len()
origPNBytes := make([]byte, 4)
copy(origPNBytes, b[hdrLen:hdrLen+4])
var salt []byte var salt []byte
if versionNumber == version1 { if versionNumber == version1 {
@ -126,31 +123,40 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) {
return "", err return "", err
} }
cache := buf.New() cache := buf.NewPacket()
defer cache.Release() defer cache.Release()
mask := cache.Extend(int(block.BlockSize())) mask := cache.Extend(block.BlockSize())
block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16]) block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16])
b[0] ^= mask[0] & 0xf firstByte := b[0]
for i := range b[hdrLen : hdrLen+4] { // Encrypt/decrypt first byte.
b[hdrLen+i] ^= mask[i+1] if isLongHeader {
// Long header: 4 bits masked
// High 4 bits are not protected.
firstByte ^= mask[0] & 0x0f
} else {
// Short header: 5 bits masked
// High 3 bits are not protected.
firstByte ^= mask[0] & 0x1f
} }
packetNumberLength := b[0]&0x3 + 1 packetNumberLength := int(firstByte&0x3 + 1) // max = 4 (64-bit sequence number)
var packetNumber uint32 extHdrLen := hdrLen + packetNumberLength
{
n, err := buffer.ReadByte() // copy to avoid modify origin data
if err != nil { extHdr := cache.Extend(extHdrLen)
return "", err copy(extHdr, b)
} extHdr[0] = firstByte
packetNumber = uint32(n)
packetNumber := extHdr[hdrLen:extHdrLen]
// Encrypt/decrypt packet number.
for i := range packetNumber {
packetNumber[i] ^= mask[1+i]
} }
if packetNumber != 0 && packetNumber != 1 { if packetNumber[0] != 0 && packetNumber[0] != 1 {
return "", errNotQuicInitial return "", errNotQuicInitial
} }
extHdrLen := hdrLen + int(packetNumberLength)
copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:])
data := b[extHdrLen : int(packetLen)+hdrLen] data := b[extHdrLen : int(packetLen)+hdrLen]
key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16) key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
@ -163,24 +169,20 @@ func (quic QuicSniffer) SniffData(b []byte) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
nonce := cache.Extend(8) // 64-bit sequence number
binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
// copy from crypto/tls.aeadAESGCMTLS13
for i, b := range nonce {
iv[4+i] ^= b
}
decrypted, err := aead.Open(b[extHdrLen:extHdrLen], iv, data, b[:extHdrLen])
// We only decrypt once, so we do not need to XOR it back. // We only decrypt once, so we do not need to XOR it back.
//for i, b := range nonce { // https://github.com/quic-go/qtls-go1-20/blob/e132a0e6cb45e20ac0b705454849a11d09ba5a54/cipher_suites.go#L496
// iv[4+i] ^= b for i, b := range packetNumber {
//} iv[len(iv)-len(packetNumber)+i] ^= b
}
dst := cache.Extend(len(data))
decrypted, err := aead.Open(dst[:0], iv, data, extHdr)
if err != nil { if err != nil {
return "", err return "", err
} }
buffer = buf.As(decrypted) buffer = buf.As(decrypted)
cryptoLen := uint(0) cryptoLen := uint(0)
cryptoData := make([]byte, buffer.Len()) cryptoData := cache.Extend(buffer.Len())
for i := 0; !buffer.IsEmpty(); i++ { for i := 0; !buffer.IsEmpty(); i++ {
frameType := byte(0x0) // Default to PADDING frame frameType := byte(0x0) // Default to PADDING frame
for frameType == 0x0 && !buffer.IsEmpty() { for frameType == 0x0 && !buffer.IsEmpty() {

View file

@ -1,6 +1,7 @@
package sniffer package sniffer
import ( import (
"bytes"
"encoding/hex" "encoding/hex"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
@ -26,9 +27,11 @@ func TestQuicHeaders(t *testing.T) {
for _, test := range cases { for _, test := range cases {
pkt, err := hex.DecodeString(test.input) pkt, err := hex.DecodeString(test.input)
assert.NoError(t, err) assert.NoError(t, err)
oriPkt := bytes.Clone(pkt)
domain, err := q.SniffData(pkt) domain, err := q.SniffData(pkt)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, test.domain, domain) assert.Equal(t, test.domain, domain)
assert.Equal(t, oriPkt, pkt) // ensure input data not changed
} }
} }
@ -170,6 +173,7 @@ func TestTLSHeaders(t *testing.T) {
} }
for _, test := range cases { for _, test := range cases {
input := bytes.Clone(test.input)
domain, err := SniffTLS(test.input) domain, err := SniffTLS(test.input)
if test.err { if test.err {
if err == nil { if err == nil {
@ -183,5 +187,6 @@ func TestTLSHeaders(t *testing.T) {
t.Error("expect domain ", test.domain, " but got ", domain) t.Error("expect domain ", test.domain, " but got ", domain)
} }
} }
assert.Equal(t, input, test.input)
} }
} }

View file

@ -4,6 +4,7 @@ import "github.com/Dreamacro/clash/constant"
type Sniffer interface { type Sniffer interface {
SupportNetwork() constant.NetWork SupportNetwork() constant.NetWork
// SniffData must not change input bytes
SniffData(bytes []byte) (string, error) SniffData(bytes []byte) (string, error)
Protocol() string Protocol() string
SupportPort(port uint16) bool SupportPort(port uint16) bool