chore: pool.WaitForTask should return typed error
This commit is contained in:
parent
33107bf3ae
commit
edd297ada2
@ -1,7 +1,6 @@
|
|||||||
package pool
|
package pool
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,6 +14,12 @@ type TaskPool struct {
|
|||||||
waitMap map[int]chan error
|
waitMap map[int]chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ErrTaskNotFound struct{}
|
||||||
|
|
||||||
|
func (m *ErrTaskNotFound) Error() string {
|
||||||
|
return "task not found"
|
||||||
|
}
|
||||||
|
|
||||||
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
|
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
|
||||||
return &TaskPool{
|
return &TaskPool{
|
||||||
workers: maxWorkers,
|
workers: maxWorkers,
|
||||||
@ -54,7 +59,7 @@ func (tp *TaskPool) WaitForTask(taskID int) error {
|
|||||||
waitChan, ok := tp.waitMap[taskID]
|
waitChan, ok := tp.waitMap[taskID]
|
||||||
if !ok {
|
if !ok {
|
||||||
tp.lck.Unlock()
|
tp.lck.Unlock()
|
||||||
return errors.New("task not found")
|
return &ErrTaskNotFound{}
|
||||||
}
|
}
|
||||||
tp.lck.Unlock()
|
tp.lck.Unlock()
|
||||||
|
|
||||||
@ -72,8 +77,8 @@ func (tp *TaskPool) markTaskComplete(taskID int, err error) {
|
|||||||
tp.lck.Lock()
|
tp.lck.Lock()
|
||||||
waitChan, ok := tp.waitMap[taskID]
|
waitChan, ok := tp.waitMap[taskID]
|
||||||
if !ok {
|
if !ok {
|
||||||
tp.lck.Unlock()
|
// should never happen here
|
||||||
return
|
panic("worker: task destroyed before completion")
|
||||||
}
|
}
|
||||||
tp.lck.Unlock()
|
tp.lck.Unlock()
|
||||||
|
|
||||||
|
@ -51,6 +51,7 @@ func TestTaskPool_WaitForTask(t *testing.T) {
|
|||||||
return func() error {
|
return func() error {
|
||||||
counter += 1
|
counter += 1
|
||||||
t.Log("task", i, "finished")
|
t.Log("task", i, "finished")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
return errors.New(strconv.Itoa(i))
|
return errors.New(strconv.Itoa(i))
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
@ -69,6 +70,31 @@ func TestTaskPool_WaitForTask(t *testing.T) {
|
|||||||
pool.Stop()
|
pool.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTaskPool_DoubleWait(t *testing.T) {
|
||||||
|
pool := NewTaskPool(1, 1)
|
||||||
|
pool.Start()
|
||||||
|
|
||||||
|
f := func() error {
|
||||||
|
t.Log("task invoked")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
id := pool.AddTask(f)
|
||||||
|
|
||||||
|
ret := pool.WaitForTask(id)
|
||||||
|
if ret != nil {
|
||||||
|
t.Errorf("task returned error: %v", ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
ret2 := pool.WaitForTask(id)
|
||||||
|
if ret2 == nil {
|
||||||
|
t.Errorf("2nd wait returned nil")
|
||||||
|
} else if !errors.Is(ret2, &ErrTaskNotFound{}) {
|
||||||
|
t.Errorf("2nd wait returned wrong error: %v", ret2)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
func TestTaskPool_One(t *testing.T) {
|
func TestTaskPool_One(t *testing.T) {
|
||||||
pool := NewTaskPool(1, 1)
|
pool := NewTaskPool(1, 1)
|
||||||
pool.Start()
|
pool.Start()
|
||||||
|
Loading…
Reference in New Issue
Block a user