chore: pool.WaitForTask should return typed error
This commit is contained in:
parent
33107bf3ae
commit
edd297ada2
@ -1,7 +1,6 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -15,6 +14,12 @@ type TaskPool struct {
|
||||
waitMap map[int]chan error
|
||||
}
|
||||
|
||||
type ErrTaskNotFound struct{}
|
||||
|
||||
func (m *ErrTaskNotFound) Error() string {
|
||||
return "task not found"
|
||||
}
|
||||
|
||||
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
|
||||
return &TaskPool{
|
||||
workers: maxWorkers,
|
||||
@ -54,7 +59,7 @@ func (tp *TaskPool) WaitForTask(taskID int) error {
|
||||
waitChan, ok := tp.waitMap[taskID]
|
||||
if !ok {
|
||||
tp.lck.Unlock()
|
||||
return errors.New("task not found")
|
||||
return &ErrTaskNotFound{}
|
||||
}
|
||||
tp.lck.Unlock()
|
||||
|
||||
@ -72,8 +77,8 @@ func (tp *TaskPool) markTaskComplete(taskID int, err error) {
|
||||
tp.lck.Lock()
|
||||
waitChan, ok := tp.waitMap[taskID]
|
||||
if !ok {
|
||||
tp.lck.Unlock()
|
||||
return
|
||||
// should never happen here
|
||||
panic("worker: task destroyed before completion")
|
||||
}
|
||||
tp.lck.Unlock()
|
||||
|
||||
|
@ -51,6 +51,7 @@ func TestTaskPool_WaitForTask(t *testing.T) {
|
||||
return func() error {
|
||||
counter += 1
|
||||
t.Log("task", i, "finished")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return errors.New(strconv.Itoa(i))
|
||||
}
|
||||
}(i)
|
||||
@ -69,6 +70,31 @@ func TestTaskPool_WaitForTask(t *testing.T) {
|
||||
pool.Stop()
|
||||
}
|
||||
|
||||
func TestTaskPool_DoubleWait(t *testing.T) {
|
||||
pool := NewTaskPool(1, 1)
|
||||
pool.Start()
|
||||
|
||||
f := func() error {
|
||||
t.Log("task invoked")
|
||||
return nil
|
||||
}
|
||||
id := pool.AddTask(f)
|
||||
|
||||
ret := pool.WaitForTask(id)
|
||||
if ret != nil {
|
||||
t.Errorf("task returned error: %v", ret)
|
||||
}
|
||||
|
||||
ret2 := pool.WaitForTask(id)
|
||||
if ret2 == nil {
|
||||
t.Errorf("2nd wait returned nil")
|
||||
} else if !errors.Is(ret2, &ErrTaskNotFound{}) {
|
||||
t.Errorf("2nd wait returned wrong error: %v", ret2)
|
||||
}
|
||||
|
||||
pool.Stop()
|
||||
}
|
||||
|
||||
func TestTaskPool_One(t *testing.T) {
|
||||
pool := NewTaskPool(1, 1)
|
||||
pool.Start()
|
||||
|
Loading…
Reference in New Issue
Block a user