diff --git a/tunnel/manager.go b/tunnel/manager.go index 50c2e116..aef8b617 100644 --- a/tunnel/manager.go +++ b/tunnel/manager.go @@ -2,23 +2,19 @@ package tunnel import ( "sync" + "sync/atomic" "time" ) var DefaultManager *Manager func init() { - DefaultManager = &Manager{ - upload: make(chan int64), - download: make(chan int64), - } + DefaultManager = &Manager{} DefaultManager.handle() } type Manager struct { connections sync.Map - upload chan int64 - download chan int64 uploadTemp int64 downloadTemp int64 uploadBlip int64 @@ -35,16 +31,18 @@ func (m *Manager) Leave(c tracker) { m.connections.Delete(c.ID()) } -func (m *Manager) Upload() chan<- int64 { - return m.upload +func (m *Manager) PushUploaded(size int64) { + atomic.AddInt64(&m.uploadTemp, size) + atomic.AddInt64(&m.uploadTotal, size) } -func (m *Manager) Download() chan<- int64 { - return m.download +func (m *Manager) PushDownloaded(size int64) { + atomic.AddInt64(&m.downloadTemp, size) + atomic.AddInt64(&m.downloadTotal, size) } func (m *Manager) Now() (up int64, down int64) { - return m.uploadBlip, m.downloadBlip + return atomic.LoadInt64(&m.uploadBlip), atomic.LoadInt64(&m.downloadBlip) } func (m *Manager) Snapshot() *Snapshot { @@ -55,8 +53,8 @@ func (m *Manager) Snapshot() *Snapshot { }) return &Snapshot{ - UploadTotal: m.uploadTotal, - DownloadTotal: m.downloadTotal, + UploadTotal: atomic.LoadInt64(&m.uploadTotal), + DownloadTotal: atomic.LoadInt64(&m.downloadTotal), Connections: connections, } } @@ -71,21 +69,18 @@ func (m *Manager) ResetStatistic() { } func (m *Manager) handle() { - go m.handleCh(m.upload, &m.uploadTemp, &m.uploadBlip, &m.uploadTotal) - go m.handleCh(m.download, &m.downloadTemp, &m.downloadBlip, &m.downloadTotal) + go m.handleCh(&m.uploadTemp, &m.uploadBlip) + go m.handleCh(&m.downloadTemp, &m.downloadBlip) } -func (m *Manager) handleCh(ch <-chan int64, temp *int64, blip *int64, total *int64) { +func (m *Manager) handleCh(temp *int64, blip *int64) { ticker := time.NewTicker(time.Second) + for { - select { - case n := <-ch: - *temp += n - *total += n - case <-ticker.C: - *blip = *temp - *temp = 0 - } + <-ticker.C + + atomic.StoreInt64(blip, atomic.LoadInt64(temp)) + atomic.StoreInt64(temp, 0) } } diff --git a/tunnel/tracker.go b/tunnel/tracker.go index 142b7110..e39caec7 100644 --- a/tunnel/tracker.go +++ b/tunnel/tracker.go @@ -37,7 +37,7 @@ func (tt *tcpTracker) ID() string { func (tt *tcpTracker) Read(b []byte) (int, error) { n, err := tt.Conn.Read(b) download := int64(n) - tt.manager.Download() <- download + tt.manager.PushDownloaded(download) tt.DownloadTotal += download return n, err } @@ -45,7 +45,7 @@ func (tt *tcpTracker) Read(b []byte) (int, error) { func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) - tt.manager.Upload() <- upload + tt.manager.PushUploaded(upload) tt.UploadTotal += upload return n, err } @@ -92,7 +92,7 @@ func (ut *udpTracker) ID() string { func (ut *udpTracker) ReadFrom(b []byte) (int, net.Addr, error) { n, addr, err := ut.PacketConn.ReadFrom(b) download := int64(n) - ut.manager.Download() <- download + ut.manager.PushDownloaded(download) ut.DownloadTotal += download return n, addr, err } @@ -100,7 +100,7 @@ func (ut *udpTracker) ReadFrom(b []byte) (int, net.Addr, error) { func (ut *udpTracker) WriteTo(b []byte, addr net.Addr) (int, error) { n, err := ut.PacketConn.WriteTo(b, addr) upload := int64(n) - ut.manager.Upload() <- upload + ut.manager.PushUploaded(upload) ut.UploadTotal += upload return n, err }