woj-server/pkg/pool/pool.go

87 lines
1.4 KiB
Go
Raw Permalink Normal View History

2024-01-06 01:50:20 +08:00
package pool
import (
"errors"
2024-01-06 01:50:20 +08:00
"sync"
)
type TaskPool struct {
workers int
queue chan Task
wg sync.WaitGroup
lck sync.Mutex
curTaskID int
waitMap map[int]chan error
2024-01-06 01:50:20 +08:00
}
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
return &TaskPool{
workers: maxWorkers,
queue: make(chan Task, bufferSize),
waitMap: make(map[int]chan error),
2024-01-06 01:50:20 +08:00
curTaskID: 1, // task id starts from 1
}
}
// Start launches tp.workers worker goroutines numbered 1..tp.workers. Each
// worker is built with NewWorker from its ID, the shared task queue and the
// pool itself. Each worker is registered on tp.wg before its goroutine is
// spawned so that Stop can wait for it. (The worker's Start method, defined
// elsewhere, is presumably responsible for calling wg.Done.)
func (tp *TaskPool) Start() {
	for n := 0; n < tp.workers; n++ {
		w := NewWorker(n+1, tp.queue, tp) // worker IDs are 1-based
		tp.wg.Add(1)
		go w.Start(&tp.wg)
	}
}
func (tp *TaskPool) AddTask(f func() error) int {
2024-01-06 01:50:20 +08:00
tp.lck.Lock()
id := tp.curTaskID
tp.curTaskID++
waitChan := make(chan error, 1)
2024-01-06 01:50:20 +08:00
tp.waitMap[id] = waitChan
2024-01-06 21:03:30 +08:00
tp.lck.Unlock()
task := Task{id: id, f: f}
tp.queue <- task
2024-01-06 01:50:20 +08:00
return id
}
func (tp *TaskPool) WaitForTask(taskID int) error {
2024-01-06 01:50:20 +08:00
tp.lck.Lock()
waitChan, ok := tp.waitMap[taskID]
if !ok {
tp.lck.Unlock()
return errors.New("task not found")
2024-01-06 01:50:20 +08:00
}
tp.lck.Unlock()
ret := <-waitChan
close(waitChan)
2024-01-06 01:50:20 +08:00
tp.lck.Lock()
delete(tp.waitMap, taskID)
tp.lck.Unlock()
return ret
}
2024-01-06 01:50:20 +08:00
func (tp *TaskPool) markTaskComplete(taskID int, err error) {
tp.lck.Lock()
2024-01-06 01:50:20 +08:00
waitChan, ok := tp.waitMap[taskID]
if !ok {
tp.lck.Unlock()
2024-01-06 01:50:20 +08:00
return
}
tp.lck.Unlock()
2024-01-06 01:50:20 +08:00
waitChan <- err
2024-01-06 01:50:20 +08:00
}
// Stop closes the task queue, presumably signalling workers that no further
// tasks will arrive. It then blocks until tp.wg reaches zero, that is, until
// every worker registered by Start has called Done.
// Call it at most once: a second call would close an already-closed channel
// and panic.
func (tp *TaskPool) Stop() {
	defer tp.wg.Wait()
	close(tp.queue)
}