chore: pool.WaitForTask should return typed error

This commit is contained in:
Paul Pan 2024-01-28 22:02:26 +08:00
parent 33107bf3ae
commit edd297ada2
Signed by: Paul
GPG Key ID: D639BDF5BA578AF4
2 changed files with 35 additions and 4 deletions

View File

@ -1,7 +1,6 @@
package pool package pool
import ( import (
"errors"
"sync" "sync"
) )
@ -15,6 +14,12 @@ type TaskPool struct {
waitMap map[int]chan error waitMap map[int]chan error
} }
type ErrTaskNotFound struct{}
func (m *ErrTaskNotFound) Error() string {
return "task not found"
}
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool { func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
return &TaskPool{ return &TaskPool{
workers: maxWorkers, workers: maxWorkers,
@ -54,7 +59,7 @@ func (tp *TaskPool) WaitForTask(taskID int) error {
waitChan, ok := tp.waitMap[taskID] waitChan, ok := tp.waitMap[taskID]
if !ok { if !ok {
tp.lck.Unlock() tp.lck.Unlock()
return errors.New("task not found") return &ErrTaskNotFound{}
} }
tp.lck.Unlock() tp.lck.Unlock()
@ -72,8 +77,8 @@ func (tp *TaskPool) markTaskComplete(taskID int, err error) {
tp.lck.Lock() tp.lck.Lock()
waitChan, ok := tp.waitMap[taskID] waitChan, ok := tp.waitMap[taskID]
if !ok { if !ok {
tp.lck.Unlock() // should never happen here
return panic("worker: task destroyed before completion")
} }
tp.lck.Unlock() tp.lck.Unlock()

View File

@ -51,6 +51,7 @@ func TestTaskPool_WaitForTask(t *testing.T) {
return func() error { return func() error {
counter += 1 counter += 1
t.Log("task", i, "finished") t.Log("task", i, "finished")
time.Sleep(100 * time.Millisecond)
return errors.New(strconv.Itoa(i)) return errors.New(strconv.Itoa(i))
} }
}(i) }(i)
@ -69,6 +70,31 @@ func TestTaskPool_WaitForTask(t *testing.T) {
pool.Stop() pool.Stop()
} }
func TestTaskPool_DoubleWait(t *testing.T) {
pool := NewTaskPool(1, 1)
pool.Start()
f := func() error {
t.Log("task invoked")
return nil
}
id := pool.AddTask(f)
ret := pool.WaitForTask(id)
if ret != nil {
t.Errorf("task returned error: %v", ret)
}
ret2 := pool.WaitForTask(id)
if ret2 == nil {
t.Errorf("2nd wait returned nil")
} else if !errors.Is(ret2, &ErrTaskNotFound{}) {
t.Errorf("2nd wait returned wrong error: %v", ret2)
}
pool.Stop()
}
func TestTaskPool_One(t *testing.T) { func TestTaskPool_One(t *testing.T) {
pool := NewTaskPool(1, 1) pool := NewTaskPool(1, 1)
pool.Start() pool.Start()