package pool import ( "sync" "sync/atomic" ) type TaskPool struct { workers int queue chan Task wg sync.WaitGroup curTaskID atomic.Uint64 waitMap sync.Map } type ErrTaskNotFound struct{} func (m *ErrTaskNotFound) Error() string { return "task not found" } type WaitBuf struct { Value interface{} Error error } func NewTaskPool(maxWorkers, bufferSize int) *TaskPool { tp := &TaskPool{ workers: maxWorkers, queue: make(chan Task, bufferSize), waitMap: sync.Map{}, curTaskID: atomic.Uint64{}, } return tp } func (tp *TaskPool) Start() { for i := 1; i <= tp.workers; i++ { // worker id starts from 1 worker := NewWorker(i, tp.queue, tp) tp.wg.Add(1) go worker.Start(&tp.wg) } } func (tp *TaskPool) AddTask(f func() (interface{}, error)) uint64 { id := tp.curTaskID.Add(1) waitChan := make(chan WaitBuf, 1) tp.waitMap.Store(id, waitChan) task := Task{id: id, f: f} tp.queue <- task return id } func (tp *TaskPool) WaitForTask(taskID uint64) WaitBuf { val, ok := tp.waitMap.Load(taskID) if !ok { return WaitBuf{nil, &ErrTaskNotFound{}} } waitChan := val.(chan WaitBuf) ret := <-waitChan close(waitChan) tp.waitMap.Delete(taskID) return ret } func (tp *TaskPool) markTaskComplete(taskID uint64, buf WaitBuf) { val, ok := tp.waitMap.Load(taskID) if !ok { // should never happen here panic("worker: task destroyed before completion") } waitChan := val.(chan WaitBuf) waitChan <- buf } func (tp *TaskPool) Stop() { close(tp.queue) tp.wg.Wait() }