package pool

import (
	"sync"
)

// TaskPool runs submitted functions on a fixed number of workers and lets
// callers wait for the result of an individual task by its ID.
type TaskPool struct {
	// workers is the number of workers launched by Start.
	workers int
	// queue carries pending tasks to the workers; Stop closes it.
	queue chan Task
	// wg is handed to every worker's Start; Stop waits on it.
	wg sync.WaitGroup

	// lck guards curTaskID and waitMap.
	lck sync.Mutex
	// curTaskID is the ID that the next AddTask call will hand out.
	curTaskID int
	// waitMap holds one result channel per task whose result has not been
	// collected yet, keyed by task ID.
	waitMap map[int]chan error
}

// ErrTaskNotFound is returned by WaitForTask when the given ID has no pending
// result: it was never issued, or its result has already been collected.
type ErrTaskNotFound struct{}

// Error implements the error interface.
func (e *ErrTaskNotFound) Error() string {
	return "task not found"
}

// NewTaskPool creates a pool that will run maxWorkers workers reading from a
// task queue with room for bufferSize pending tasks. No worker runs until
// Start is called.
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
	return &TaskPool{
		workers:   maxWorkers,
		queue:     make(chan Task, bufferSize),
		waitMap:   make(map[int]chan error),
		curTaskID: 1, // task id starts from 1
	}
}

// Start launches the pool's workers, each in its own goroutine.
func (tp *TaskPool) Start() {
	for i := 1; i <= tp.workers; i++ { // worker id starts from 1
		worker := NewWorker(i, tp.queue, tp)
		tp.wg.Add(1)
		// NOTE(review): worker.Start is presumably responsible for calling
		// Done on the WaitGroup when it exits; confirm in the worker code.
		go worker.Start(&tp.wg)
	}
}

// AddTask queues f for execution and returns the ID to pass to WaitForTask.
//
// The call blocks while the queue buffer is full. It must not be called after
// Stop, because sending on the closed queue panics.
func (tp *TaskPool) AddTask(f func() error) int {
	tp.lck.Lock()
	id := tp.curTaskID
	tp.curTaskID++
	// Buffered with capacity 1 so that reporting the result never blocks,
	// even when nobody is waiting for this task yet.
	tp.waitMap[id] = make(chan error, 1)
	tp.lck.Unlock()

	// The result channel is registered before the task is queued so that it
	// already exists when the task's completion is reported.
	tp.queue <- Task{id: id, f: f}
	return id
}

// WaitForTask blocks until the task identified by taskID has completed and
// returns the error that was reported for it (nil on success).
//
// A result can be collected only once. Any other call for the same ID,
// whether it runs concurrently or afterwards, returns *ErrTaskNotFound, as
// does a call with an ID that was never issued.
func (tp *TaskPool) WaitForTask(taskID int) error {
	tp.lck.Lock()
	waitChan, ok := tp.waitMap[taskID]
	tp.lck.Unlock()
	if !ok {
		return &ErrTaskNotFound{}
	}

	// markTaskComplete sends exactly one value and then closes the channel.
	// The first receiver gets that value; a concurrent duplicate waiter sees
	// the closed, drained channel. The receiver must not close the channel
	// itself: a second waiter doing so would panic with a double close.
	ret, ok := <-waitChan
	if !ok {
		return &ErrTaskNotFound{}
	}

	tp.lck.Lock()
	delete(tp.waitMap, taskID)
	tp.lck.Unlock()
	return ret
}

// markTaskComplete records err as the result of the task identified by
// taskID and releases every goroutine waiting on it. It must be called
// exactly once per task.
//
// It panics if the task has no registered result channel, which would mean
// the pool's bookkeeping is broken.
func (tp *TaskPool) markTaskComplete(taskID int, err error) {
	tp.lck.Lock()
	waitChan, ok := tp.waitMap[taskID]
	// Release the lock before the invariant check so that a recovered panic
	// cannot leave the pool permanently locked.
	tp.lck.Unlock()
	if !ok {
		// should never happen here
		panic("worker: task destroyed before completion")
	}

	waitChan <- err
	// Closing after the single send wakes any duplicate waiter instead of
	// leaving it blocked forever.
	close(waitChan)
}

// Stop closes the task queue and blocks until the pool's WaitGroup is
// released. Results that are never collected with WaitForTask stay
// registered in the pool.
func (tp *TaskPool) Stop() {
	close(tp.queue)
	tp.wg.Wait()
}