package pool

import (
	"errors"
	"sync"
	"testing"
	"time"
)

// newSummingTask returns a task that adds i to *counter while holding mu,
// then sleeps for i*unit and returns (nil, nil).
//
// The lock is released before sleeping so concurrently running tasks are not
// serialised on mu during their sleep. Because i is a parameter, every task
// gets its own copy regardless of the Go version's loop-variable semantics.
func newSummingTask(t *testing.T, mu *sync.Mutex, counter *int, i int, unit time.Duration) func() (interface{}, error) {
	return func() (interface{}, error) {
		mu.Lock()
		t.Log("task", i, "locked")
		*counter += i
		t.Log("task", i, "unlocked")
		mu.Unlock()
		time.Sleep(time.Duration(i) * unit)
		t.Log("task", i, "finished")
		return nil, nil
	}
}

// TestTaskPool_Stop enqueues ten summing tasks on a pool created with
// NewTaskPool(5, 10) and checks that all of them have run by the time Stop
// returns.
func TestTaskPool_Stop(t *testing.T) {
	const numTasks = 10
	pool := NewTaskPool(5, 10)
	pool.Start()

	var mu sync.Mutex
	counter := 0
	for i := 1; i <= numTasks; i++ {
		pool.AddTask(newSummingTask(t, &mu, &counter, i, 100*time.Millisecond))
	}
	pool.Stop()

	// counter is only ever written while holding mu, so read it under mu too.
	mu.Lock()
	got := counter
	mu.Unlock()
	if want := numTasks * (numTasks + 1) / 2; got != want {
		t.Errorf("some tasks were not executed: counter = %d, want %d", got, want)
	}
}

// TestTaskPool_WaitForTask submits tasks one at a time and checks that
// WaitForTask returns only after the task body has run, and that it delivers
// the task's return value with a nil error.
func TestTaskPool_WaitForTask(t *testing.T) {
	pool := NewTaskPool(10, 10)
	pool.Start()
	defer pool.Stop()

	// counter is shared with the worker without a lock on purpose: this
	// assumes WaitForTask provides the happens-before edge with the task's
	// completion, which `go test -race` will flag if it does not.
	counter := 0
	for i := 1; i <= 10; i++ {
		f := func(i int) func() (interface{}, error) {
			return func() (interface{}, error) {
				counter++
				t.Log("task", i, "finished")
				time.Sleep(100 * time.Millisecond)
				return i, nil
			}
		}(i)
		id := pool.AddTask(f)
		ret := pool.WaitForTask(id)
		if counter != 1 {
			t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id)
		}
		// Check the value only on success, and with comma-ok, so a nil or
		// non-int Value reports a failure instead of panicking the test.
		if ret.Error != nil {
			t.Errorf("Return error: %v, task %d", ret.Error, id)
		} else if v, ok := ret.Value.(int); !ok || v != i {
			t.Errorf("Return value mismatch: expected %d, got %v, task %d", i, ret.Value, id)
		}
		counter--
	}
}

// TestTaskPool_DoubleWait checks that a task's result can be collected only
// once: a second WaitForTask on the same id must fail with ErrTaskNotFound.
func TestTaskPool_DoubleWait(t *testing.T) {
	pool := NewTaskPool(1, 1)
	pool.Start()
	defer pool.Stop()

	f := func() (interface{}, error) {
		t.Log("task invoked")
		return nil, nil
	}
	id := pool.AddTask(f)

	if ret := pool.WaitForTask(id); ret.Error != nil {
		t.Errorf("task returned error: %v", ret)
	}

	ret2 := pool.WaitForTask(id)
	if ret2.Error == nil {
		t.Fatalf("2nd wait returned nil error")
	}
	// errors.Is against a freshly allocated &ErrTaskNotFound{} can only
	// succeed through an Is method on the error type (pointer equality with a
	// new allocation is not reliable), so also accept any *ErrTaskNotFound
	// found in the chain via errors.As.
	var notFound *ErrTaskNotFound
	if !errors.Is(ret2.Error, &ErrTaskNotFound{}) && !errors.As(ret2.Error, &notFound) {
		t.Errorf("2nd wait returned wrong error: %v", ret2)
	}
}

// TestTaskPool_One runs ten summing tasks through a pool created with
// NewTaskPool(1, 1), waits for each by id, and checks that every task ran and
// none reported an error.
func TestTaskPool_One(t *testing.T) {
	const numTasks = 10
	pool := NewTaskPool(1, 1)
	pool.Start()
	defer pool.Stop()

	var mu sync.Mutex
	counter := 0
	ids := make([]uint64, 0, numTasks)
	for i := 1; i <= numTasks; i++ {
		ids = append(ids, pool.AddTask(newSummingTask(t, &mu, &counter, i, 10*time.Millisecond)))
	}
	for _, id := range ids {
		if ret := pool.WaitForTask(id); ret.Error != nil {
			t.Errorf("task %d returned error: %v", id, ret.Error)
		}
	}

	// counter is only ever written while holding mu, so read it under mu too.
	mu.Lock()
	got := counter
	mu.Unlock()
	if want := numTasks * (numTasks + 1) / 2; got != want {
		t.Errorf("some tasks were not executed: counter = %d, want %d", got, want)
	}
}