feat: add a simple task pool
This commit is contained in:
parent
3ff0a21e5d
commit
cfeaaacc69
78
pkg/pool/pool.go
Normal file
78
pkg/pool/pool.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskPool struct {
|
||||||
|
workers int
|
||||||
|
queue chan Task
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
lck sync.Mutex
|
||||||
|
curTaskID int
|
||||||
|
waitMap map[int]chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
|
||||||
|
return &TaskPool{
|
||||||
|
workers: maxWorkers,
|
||||||
|
queue: make(chan Task, bufferSize),
|
||||||
|
waitMap: make(map[int]chan struct{}),
|
||||||
|
curTaskID: 1, // task id starts from 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *TaskPool) Start() {
|
||||||
|
for i := 1; i <= tp.workers; i++ { // worker id starts from 1
|
||||||
|
worker := NewWorker(i, tp.queue, tp)
|
||||||
|
tp.wg.Add(1)
|
||||||
|
go worker.Start(&tp.wg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *TaskPool) AddTask(f func()) int {
|
||||||
|
tp.lck.Lock()
|
||||||
|
defer tp.lck.Unlock()
|
||||||
|
|
||||||
|
id := tp.curTaskID
|
||||||
|
tp.curTaskID++
|
||||||
|
|
||||||
|
task := Task{id: id, f: f}
|
||||||
|
tp.queue <- task
|
||||||
|
|
||||||
|
waitChan := make(chan struct{})
|
||||||
|
tp.waitMap[id] = waitChan
|
||||||
|
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *TaskPool) WaitForTask(taskID int) {
|
||||||
|
tp.lck.Lock()
|
||||||
|
waitChan, ok := tp.waitMap[taskID]
|
||||||
|
if !ok {
|
||||||
|
tp.lck.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tp.lck.Unlock()
|
||||||
|
|
||||||
|
<-waitChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *TaskPool) MarkTaskComplete(taskID int) {
|
||||||
|
tp.lck.Lock()
|
||||||
|
defer tp.lck.Unlock()
|
||||||
|
|
||||||
|
waitChan, ok := tp.waitMap[taskID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(waitChan)
|
||||||
|
delete(tp.waitMap, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *TaskPool) Stop() {
|
||||||
|
close(tp.queue)
|
||||||
|
tp.wg.Wait()
|
||||||
|
}
|
62
pkg/pool/pool_test.go
Normal file
62
pkg/pool/pool_test.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTaskPool_Stop(t *testing.T) {
|
||||||
|
pool := NewTaskPool(5, 10)
|
||||||
|
pool.Start()
|
||||||
|
|
||||||
|
lck := sync.Mutex{}
|
||||||
|
counter := 0
|
||||||
|
|
||||||
|
for i := 1; i <= 10; i++ {
|
||||||
|
f := func(i int) func() {
|
||||||
|
return func() {
|
||||||
|
lck.Lock()
|
||||||
|
t.Log("task", i, "locked")
|
||||||
|
counter += i
|
||||||
|
t.Log("task", i, "unlocked")
|
||||||
|
lck.Unlock()
|
||||||
|
|
||||||
|
time.Sleep(time.Duration(i*100) * time.Millisecond)
|
||||||
|
t.Log("task", i, "finished")
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
pool.AddTask(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.Stop()
|
||||||
|
|
||||||
|
if counter != 55 {
|
||||||
|
t.Error("some tasks were not executed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskPool_WaitForTask(t *testing.T) {
|
||||||
|
pool := NewTaskPool(10, 10)
|
||||||
|
pool.Start()
|
||||||
|
|
||||||
|
counter := 0
|
||||||
|
|
||||||
|
for i := 1; i <= 10; i++ {
|
||||||
|
f := func(i int) func() {
|
||||||
|
return func() {
|
||||||
|
counter += 1
|
||||||
|
t.Log("task", i, "finished")
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
id := pool.AddTask(f)
|
||||||
|
|
||||||
|
pool.WaitForTask(id)
|
||||||
|
if counter != 1 {
|
||||||
|
t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id)
|
||||||
|
}
|
||||||
|
counter -= 1
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.Stop()
|
||||||
|
}
|
6
pkg/pool/task.go
Normal file
6
pkg/pool/task.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
type Task struct {
|
||||||
|
id int
|
||||||
|
f func()
|
||||||
|
}
|
24
pkg/pool/worker.go
Normal file
24
pkg/pool/worker.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Worker struct {
|
||||||
|
id int
|
||||||
|
queue chan Task
|
||||||
|
pool *TaskPool // back reference to the pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWorker(id int, queue chan Task, pool *TaskPool) *Worker {
|
||||||
|
return &Worker{id: id, queue: queue, pool: pool}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Worker) Start(wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
for task := range w.queue {
|
||||||
|
task.f()
|
||||||
|
w.pool.MarkTaskComplete(task.id)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user