feat: add a simple task pool

This commit is contained in:
Paul Pan 2024-01-06 01:50:20 +08:00
parent 3ff0a21e5d
commit cfeaaacc69
Signed by: Paul
GPG Key ID: D639BDF5BA578AF4
4 changed files with 170 additions and 0 deletions

78
pkg/pool/pool.go Normal file
View File

@ -0,0 +1,78 @@
package pool
import (
"sync"
)
type TaskPool struct {
workers int
queue chan Task
wg sync.WaitGroup
lck sync.Mutex
curTaskID int
waitMap map[int]chan struct{}
}
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
return &TaskPool{
workers: maxWorkers,
queue: make(chan Task, bufferSize),
waitMap: make(map[int]chan struct{}),
curTaskID: 1, // task id starts from 1
}
}
func (tp *TaskPool) Start() {
for i := 1; i <= tp.workers; i++ { // worker id starts from 1
worker := NewWorker(i, tp.queue, tp)
tp.wg.Add(1)
go worker.Start(&tp.wg)
}
}
func (tp *TaskPool) AddTask(f func()) int {
tp.lck.Lock()
defer tp.lck.Unlock()
id := tp.curTaskID
tp.curTaskID++
task := Task{id: id, f: f}
tp.queue <- task
waitChan := make(chan struct{})
tp.waitMap[id] = waitChan
return id
}
func (tp *TaskPool) WaitForTask(taskID int) {
tp.lck.Lock()
waitChan, ok := tp.waitMap[taskID]
if !ok {
tp.lck.Unlock()
return
}
tp.lck.Unlock()
<-waitChan
}
func (tp *TaskPool) MarkTaskComplete(taskID int) {
tp.lck.Lock()
defer tp.lck.Unlock()
waitChan, ok := tp.waitMap[taskID]
if !ok {
return
}
close(waitChan)
delete(tp.waitMap, taskID)
}
func (tp *TaskPool) Stop() {
close(tp.queue)
tp.wg.Wait()
}

62
pkg/pool/pool_test.go Normal file
View File

@ -0,0 +1,62 @@
package pool
import (
"sync"
"testing"
"time"
)
func TestTaskPool_Stop(t *testing.T) {
pool := NewTaskPool(5, 10)
pool.Start()
lck := sync.Mutex{}
counter := 0
for i := 1; i <= 10; i++ {
f := func(i int) func() {
return func() {
lck.Lock()
t.Log("task", i, "locked")
counter += i
t.Log("task", i, "unlocked")
lck.Unlock()
time.Sleep(time.Duration(i*100) * time.Millisecond)
t.Log("task", i, "finished")
}
}(i)
pool.AddTask(f)
}
pool.Stop()
if counter != 55 {
t.Error("some tasks were not executed")
}
}
func TestTaskPool_WaitForTask(t *testing.T) {
pool := NewTaskPool(10, 10)
pool.Start()
counter := 0
for i := 1; i <= 10; i++ {
f := func(i int) func() {
return func() {
counter += 1
t.Log("task", i, "finished")
}
}(i)
id := pool.AddTask(f)
pool.WaitForTask(id)
if counter != 1 {
t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id)
}
counter -= 1
}
pool.Stop()
}

6
pkg/pool/task.go Normal file
View File

@ -0,0 +1,6 @@
package pool
type Task struct {
id int
f func()
}

24
pkg/pool/worker.go Normal file
View File

@ -0,0 +1,24 @@
package pool
import (
"sync"
)
type Worker struct {
id int
queue chan Task
pool *TaskPool // back reference to the pool
}
func NewWorker(id int, queue chan Task, pool *TaskPool) *Worker {
return &Worker{id: id, queue: queue, pool: pool}
}
func (w *Worker) Start(wg *sync.WaitGroup) {
defer wg.Done()
for task := range w.queue {
task.f()
w.pool.MarkTaskComplete(task.id)
}
}