clash/transport/vmess/chunk.go

103 lines
1.8 KiB
Go
Raw Normal View History

2018-09-06 02:53:29 +00:00
package vmess
import (
"encoding/binary"
"errors"
"io"
2019-04-23 15:29:36 +00:00
"github.com/Dreamacro/clash/common/pool"
2018-09-06 02:53:29 +00:00
)
// Framing parameters for the length-prefixed chunk stream.
const (
	// lenSize is the width, in bytes, of the big-endian length prefix
	// that precedes every chunk.
	lenSize = 2

	// chunkSize is the largest payload chunkWriter places in one chunk (16 KiB).
	chunkSize = 16 * 1024

	// maxSize is the largest chunk length chunkReader accepts (17 KiB). It is
	// an upper bound on lenSize + chunkSize + an AEAD overhead.
	maxSize = 17 * 1024
)
// chunkReader decodes a stream of length-prefixed chunks read from the
// embedded io.Reader and exposes only their payload bytes.
type chunkReader struct {
	io.Reader // source of the framed chunks

	// buf holds the remainder of a chunk that did not fit into the caller's
	// slice (nil when nothing is pending); sizeBuf is scratch space for the
	// length prefix of the next chunk.
	buf, sizeBuf []byte

	// offset is how many bytes of buf have already been handed out.
	offset int
}
// newChunkReader wraps reader so that Read yields the payloads of the
// length-prefixed chunks it carries. The header scratch buffer is allocated
// once here and reused for every chunk.
func newChunkReader(reader io.Reader) *chunkReader {
	cr := new(chunkReader)
	cr.Reader = reader
	cr.sizeBuf = make([]byte, lenSize)
	return cr
}
// newChunkWriter wraps writer so that every Write is framed as
// length-prefixed chunks.
//
// Only the io.Writer behaviour of the argument is kept: chunkWriter embeds a
// plain io.Writer, so it neither exposes nor calls Close. The parameter is
// therefore typed io.Writer; callers that pass an io.WriteCloser still
// compile unchanged and remain responsible for closing it themselves.
func newChunkWriter(writer io.Writer) *chunkWriter {
	return &chunkWriter{Writer: writer}
}
func (cr *chunkReader) Read(b []byte) (int, error) {
if cr.buf != nil {
n := copy(b, cr.buf[cr.offset:])
cr.offset += n
if cr.offset == len(cr.buf) {
2020-04-24 16:30:40 +00:00
pool.Put(cr.buf)
2018-09-06 02:53:29 +00:00
cr.buf = nil
}
return n, nil
}
_, err := io.ReadFull(cr.Reader, cr.sizeBuf)
if err != nil {
return 0, err
}
size := int(binary.BigEndian.Uint16(cr.sizeBuf))
if size > maxSize {
2020-08-25 14:19:59 +00:00
return 0, errors.New("buffer is larger than standard")
2018-09-06 02:53:29 +00:00
}
if len(b) >= size {
_, err := io.ReadFull(cr.Reader, b[:size])
if err != nil {
return 0, err
}
return size, nil
}
2020-04-24 16:30:40 +00:00
buf := pool.Get(size)
_, err = io.ReadFull(cr.Reader, buf)
2018-09-06 02:53:29 +00:00
if err != nil {
2020-04-24 16:30:40 +00:00
pool.Put(buf)
2018-09-06 02:53:29 +00:00
return 0, err
}
2020-04-24 16:30:40 +00:00
n := copy(b, buf)
2018-09-06 02:53:29 +00:00
cr.offset = n
2020-04-24 16:30:40 +00:00
cr.buf = buf
2018-09-06 02:53:29 +00:00
return n, nil
}
// chunkWriter frames everything written to it as length-prefixed chunks and
// forwards them to the embedded io.Writer.
type chunkWriter struct {
	io.Writer // destination for the framed chunks
}
func (cw *chunkWriter) Write(b []byte) (n int, err error) {
2020-04-24 16:30:40 +00:00
buf := pool.Get(pool.RelayBufferSize)
defer pool.Put(buf)
2018-09-06 02:53:29 +00:00
length := len(b)
for {
if length == 0 {
break
}
readLen := chunkSize
if length < chunkSize {
readLen = length
}
payloadBuf := buf[lenSize : lenSize+chunkSize]
copy(payloadBuf, b[n:n+readLen])
binary.BigEndian.PutUint16(buf[:lenSize], uint16(readLen))
_, err = cw.Writer.Write(buf[:lenSize+readLen])
if err != nil {
break
}
n += readLen
length -= readLen
}
return
}