From cfeaaacc69acbecdc977260be32ce786be10c27f Mon Sep 17 00:00:00 2001 From: Paul Pan Date: Sat, 6 Jan 2024 01:50:20 +0800 Subject: [PATCH] feat: add a simple task pool --- pkg/pool/pool.go | 78 +++++++++++++++++++++++++++++++++++++++++++ pkg/pool/pool_test.go | 62 ++++++++++++++++++++++++++++++++++ pkg/pool/task.go | 6 ++++ pkg/pool/worker.go | 24 +++++++++++++ 4 files changed, 170 insertions(+) create mode 100644 pkg/pool/pool.go create mode 100644 pkg/pool/pool_test.go create mode 100644 pkg/pool/task.go create mode 100644 pkg/pool/worker.go diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go new file mode 100644 index 0000000..45e1df6 --- /dev/null +++ b/pkg/pool/pool.go @@ -0,0 +1,78 @@ +package pool + +import ( + "sync" +) + +type TaskPool struct { + workers int + queue chan Task + wg sync.WaitGroup + + lck sync.Mutex + curTaskID int + waitMap map[int]chan struct{} +} + +func NewTaskPool(maxWorkers, bufferSize int) *TaskPool { + return &TaskPool{ + workers: maxWorkers, + queue: make(chan Task, bufferSize), + waitMap: make(map[int]chan struct{}), + curTaskID: 1, // task id starts from 1 + } +} + +func (tp *TaskPool) Start() { + for i := 1; i <= tp.workers; i++ { // worker id starts from 1 + worker := NewWorker(i, tp.queue, tp) + tp.wg.Add(1) + go worker.Start(&tp.wg) + } +} + +func (tp *TaskPool) AddTask(f func()) int { + tp.lck.Lock() + defer tp.lck.Unlock() + + id := tp.curTaskID + tp.curTaskID++ + + task := Task{id: id, f: f} + tp.queue <- task + + waitChan := make(chan struct{}) + tp.waitMap[id] = waitChan + + return id +} + +func (tp *TaskPool) WaitForTask(taskID int) { + tp.lck.Lock() + waitChan, ok := tp.waitMap[taskID] + if !ok { + tp.lck.Unlock() + return + } + tp.lck.Unlock() + + <-waitChan +} + +func (tp *TaskPool) MarkTaskComplete(taskID int) { + tp.lck.Lock() + defer tp.lck.Unlock() + + waitChan, ok := tp.waitMap[taskID] + if !ok { + return + } + + close(waitChan) + delete(tp.waitMap, taskID) +} + +func (tp *TaskPool) Stop() { + close(tp.queue) + tp.wg.Wait() +} diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go new file mode 100644 index 0000000..fe21f03 --- /dev/null +++ b/pkg/pool/pool_test.go @@ -0,0 +1,62 @@ +package pool + +import ( + "sync" + "testing" + "time" +) + +func TestTaskPool_Stop(t *testing.T) { + pool := NewTaskPool(5, 10) + pool.Start() + + lck := sync.Mutex{} + counter := 0 + + for i := 1; i <= 10; i++ { + f := func(i int) func() { + return func() { + lck.Lock() + t.Log("task", i, "locked") + counter += i + t.Log("task", i, "unlocked") + lck.Unlock() + + time.Sleep(time.Duration(i*100) * time.Millisecond) + t.Log("task", i, "finished") + } + }(i) + pool.AddTask(f) + } + + pool.Stop() + + if counter != 55 { + t.Error("some tasks were not executed") + } +} + +func TestTaskPool_WaitForTask(t *testing.T) { + pool := NewTaskPool(10, 10) + pool.Start() + + counter := 0 + + for i := 1; i <= 10; i++ { + f := func(i int) func() { + return func() { + counter += 1 + t.Log("task", i, "finished") + } + }(i) + id := pool.AddTask(f) + + pool.WaitForTask(id) + if counter != 1 { + t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id) + } + counter -= 1 + } + + pool.Stop() +} diff --git a/pkg/pool/task.go b/pkg/pool/task.go new file mode 100644 index 0000000..d3f8b7c --- /dev/null +++ b/pkg/pool/task.go @@ -0,0 +1,6 @@ +package pool + +type Task struct { + id int + f func() +} diff --git a/pkg/pool/worker.go b/pkg/pool/worker.go new file mode 100644 index 0000000..de5f55a --- /dev/null +++ b/pkg/pool/worker.go @@ -0,0 +1,24 @@ +package pool + +import ( + "sync" +) + +type Worker struct { + id int + queue chan Task + pool *TaskPool // back reference to the pool +} + +func NewWorker(id int, queue chan Task, pool *TaskPool) *Worker { + return &Worker{id: id, queue: queue, pool: pool} +} + +func (w *Worker) Start(wg *sync.WaitGroup) { + defer wg.Done() + + for task := range w.queue { + task.f() + w.pool.MarkTaskComplete(task.id) + } +}