From edd297ada29a1de9761e2ed1857dd2b6baebd9b7 Mon Sep 17 00:00:00 2001 From: Paul Pan Date: Sun, 28 Jan 2024 22:02:26 +0800 Subject: [PATCH] chore: pool.WaitForTask should return typed error --- pkg/pool/pool.go | 13 +++++++++---- pkg/pool/pool_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index 4106fdb..4bc1fcd 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -1,7 +1,6 @@ package pool import ( - "errors" "sync" ) @@ -15,6 +14,12 @@ type TaskPool struct { waitMap map[int]chan error } +type ErrTaskNotFound struct{} + +func (m *ErrTaskNotFound) Error() string { + return "task not found" +} + func NewTaskPool(maxWorkers, bufferSize int) *TaskPool { return &TaskPool{ workers: maxWorkers, @@ -54,7 +59,7 @@ func (tp *TaskPool) WaitForTask(taskID int) error { waitChan, ok := tp.waitMap[taskID] if !ok { tp.lck.Unlock() - return errors.New("task not found") + return &ErrTaskNotFound{} } tp.lck.Unlock() @@ -72,8 +77,8 @@ func (tp *TaskPool) markTaskComplete(taskID int, err error) { tp.lck.Lock() waitChan, ok := tp.waitMap[taskID] if !ok { - tp.lck.Unlock() - return + // should never happen here + panic("worker: task destroyed before completion") } tp.lck.Unlock() diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go index a06389e..fdbfa00 100644 --- a/pkg/pool/pool_test.go +++ b/pkg/pool/pool_test.go @@ -51,6 +51,7 @@ func TestTaskPool_WaitForTask(t *testing.T) { return func() error { counter += 1 t.Log("task", i, "finished") + time.Sleep(100 * time.Millisecond) return errors.New(strconv.Itoa(i)) } }(i) @@ -69,6 +70,31 @@ func TestTaskPool_WaitForTask(t *testing.T) { pool.Stop() } +func TestTaskPool_DoubleWait(t *testing.T) { + pool := NewTaskPool(1, 1) + pool.Start() + + f := func() error { + t.Log("task invoked") + return nil + } + id := pool.AddTask(f) + + ret := pool.WaitForTask(id) + if ret != nil { + t.Errorf("task returned error: %v", ret) + } + + ret2 := pool.WaitForTask(id) + if ret2 == nil { + t.Errorf("2nd wait returned nil") + } else if !errors.Is(ret2, &ErrTaskNotFound{}) { + t.Errorf("2nd wait returned wrong error: %v", ret2) + } + + pool.Stop() +} + func TestTaskPool_One(t *testing.T) { pool := NewTaskPool(1, 1) pool.Start()