diff --git a/common/observable/observable.go b/common/observable/observable.go index 4b216d4f..64bd0a0a 100644 --- a/common/observable/observable.go +++ b/common/observable/observable.go @@ -7,60 +7,58 @@ import ( type Observable struct { iterable Iterable - listener *sync.Map + listener map[Subscription]*Subscriber + mux sync.Mutex done bool - doneLock sync.RWMutex } func (o *Observable) process() { for item := range o.iterable { - o.listener.Range(func(key, value interface{}) bool { - elm := value.(*Subscriber) - elm.Emit(item) - return true - }) + o.mux.Lock() + for _, sub := range o.listener { + sub.Emit(item) + } + o.mux.Unlock() } o.close() } func (o *Observable) close() { - o.doneLock.Lock() - o.done = true - o.doneLock.Unlock() + o.mux.Lock() + defer o.mux.Unlock() - o.listener.Range(func(key, value interface{}) bool { - elm := value.(*Subscriber) - elm.Close() - return true - }) + o.done = true + for _, sub := range o.listener { + sub.Close() + } } func (o *Observable) Subscribe() (Subscription, error) { - o.doneLock.RLock() - done := o.done - o.doneLock.RUnlock() - if done == true { + o.mux.Lock() + defer o.mux.Unlock() + if o.done { return nil, errors.New("Observable is closed") } subscriber := newSubscriber() - o.listener.Store(subscriber.Out(), subscriber) + o.listener[subscriber.Out()] = subscriber return subscriber.Out(), nil } func (o *Observable) UnSubscribe(sub Subscription) { - elm, exist := o.listener.Load(sub) + o.mux.Lock() + defer o.mux.Unlock() + subscriber, exist := o.listener[sub] if !exist { return } - subscriber := elm.(*Subscriber) - o.listener.Delete(subscriber.Out()) + delete(o.listener, sub) subscriber.Close() } func NewObservable(any Iterable) *Observable { observable := &Observable{ iterable: any, - listener: &sync.Map{}, + listener: map[Subscription]*Subscriber{}, } go observable.process() return observable diff --git a/common/observable/observable_test.go b/common/observable/observable_test.go index 41ee272d..d965fa3b 100644 --- a/common/observable/observable_test.go +++ b/common/observable/observable_test.go @@ -5,6 +5,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func iterator(item []interface{}) chan interface{} { @@ -23,16 +25,12 @@ func TestObservable(t *testing.T) { iter := iterator([]interface{}{1, 2, 3, 4, 5}) src := NewObservable(iter) data, err := src.Subscribe() - if err != nil { - t.Error(err) - } + assert.Nil(t, err) count := 0 for range data { count++ } - if count != 5 { - t.Error("Revc number error") - } + assert.Equal(t, count, 5) } func TestObservable_MutilSubscribe(t *testing.T) { @@ -53,23 +51,17 @@ func TestObservable_MutilSubscribe(t *testing.T) { go waitCh(ch1) go waitCh(ch2) wg.Wait() - if count != 10 { - t.Error("Revc number error") - } + assert.Equal(t, count, 10) } func TestObservable_UnSubscribe(t *testing.T) { iter := iterator([]interface{}{1, 2, 3, 4, 5}) src := NewObservable(iter) data, err := src.Subscribe() - if err != nil { - t.Error(err) - } + assert.Nil(t, err) src.UnSubscribe(data) _, open := <-data - if open { - t.Error("Revc number error") - } + assert.False(t, open) } func TestObservable_SubscribeClosedSource(t *testing.T) { @@ -79,9 +71,7 @@ func TestObservable_SubscribeClosedSource(t *testing.T) { <-data _, closed := src.Subscribe() - if closed == nil { - t.Error("Observable should be closed") - } + assert.NotNil(t, closed) } func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) { @@ -118,7 +108,5 @@ func TestObservable_SubscribeGoroutineLeak(t *testing.T) { } wg.Wait() now := runtime.NumGoroutine() - if init != now { - t.Errorf("Goroutine Leak: init %d now %d", init, now) - } + assert.Equal(t, init, now) }