79 lines
1.2 KiB
Go
79 lines
1.2 KiB
Go
|
package pool
|
||
|
|
||
|
import (
|
||
|
"sync"
|
||
|
)
|
||
|
|
||
|
type TaskPool struct {
|
||
|
workers int
|
||
|
queue chan Task
|
||
|
wg sync.WaitGroup
|
||
|
|
||
|
lck sync.Mutex
|
||
|
curTaskID int
|
||
|
waitMap map[int]chan struct{}
|
||
|
}
|
||
|
|
||
|
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
|
||
|
return &TaskPool{
|
||
|
workers: maxWorkers,
|
||
|
queue: make(chan Task, bufferSize),
|
||
|
waitMap: make(map[int]chan struct{}),
|
||
|
curTaskID: 1, // task id starts from 1
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (tp *TaskPool) Start() {
|
||
|
for i := 1; i <= tp.workers; i++ { // worker id starts from 1
|
||
|
worker := NewWorker(i, tp.queue, tp)
|
||
|
tp.wg.Add(1)
|
||
|
go worker.Start(&tp.wg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (tp *TaskPool) AddTask(f func()) int {
|
||
|
tp.lck.Lock()
|
||
|
defer tp.lck.Unlock()
|
||
|
|
||
|
id := tp.curTaskID
|
||
|
tp.curTaskID++
|
||
|
|
||
|
task := Task{id: id, f: f}
|
||
|
tp.queue <- task
|
||
|
|
||
|
waitChan := make(chan struct{})
|
||
|
tp.waitMap[id] = waitChan
|
||
|
|
||
|
return id
|
||
|
}
|
||
|
|
||
|
func (tp *TaskPool) WaitForTask(taskID int) {
|
||
|
tp.lck.Lock()
|
||
|
waitChan, ok := tp.waitMap[taskID]
|
||
|
if !ok {
|
||
|
tp.lck.Unlock()
|
||
|
return
|
||
|
}
|
||
|
tp.lck.Unlock()
|
||
|
|
||
|
<-waitChan
|
||
|
}
|
||
|
|
||
|
func (tp *TaskPool) MarkTaskComplete(taskID int) {
|
||
|
tp.lck.Lock()
|
||
|
defer tp.lck.Unlock()
|
||
|
|
||
|
waitChan, ok := tp.waitMap[taskID]
|
||
|
if !ok {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
close(waitChan)
|
||
|
delete(tp.waitMap, taskID)
|
||
|
}
|
||
|
|
||
|
func (tp *TaskPool) Stop() {
|
||
|
close(tp.queue)
|
||
|
tp.wg.Wait()
|
||
|
}
|