package pool import ( "sync" ) type TaskPool struct { workers int queue chan Task wg sync.WaitGroup lck sync.Mutex curTaskID int waitMap map[int]chan struct{} } func NewTaskPool(maxWorkers, bufferSize int) *TaskPool { return &TaskPool{ workers: maxWorkers, queue: make(chan Task, bufferSize), waitMap: make(map[int]chan struct{}), curTaskID: 1, // task id starts from 1 } } func (tp *TaskPool) Start() { for i := 1; i <= tp.workers; i++ { // worker id starts from 1 worker := NewWorker(i, tp.queue, tp) tp.wg.Add(1) go worker.Start(&tp.wg) } } func (tp *TaskPool) AddTask(f func()) int { tp.lck.Lock() id := tp.curTaskID tp.curTaskID++ waitChan := make(chan struct{}) tp.waitMap[id] = waitChan tp.lck.Unlock() task := Task{id: id, f: f} tp.queue <- task return id } func (tp *TaskPool) WaitForTask(taskID int) { tp.lck.Lock() waitChan, ok := tp.waitMap[taskID] if !ok { tp.lck.Unlock() return } tp.lck.Unlock() <-waitChan } func (tp *TaskPool) markTaskComplete(taskID int) { tp.lck.Lock() defer tp.lck.Unlock() waitChan, ok := tp.waitMap[taskID] if !ok { return } close(waitChan) delete(tp.waitMap, taskID) } func (tp *TaskPool) Stop() { close(tp.queue) tp.wg.Wait() }