woj-server/pkg/pool/pool.go

87 lines
1.5 KiB
Go
Raw Permalink Normal View History

2024-01-06 01:50:20 +08:00
package pool
import (
"sync"
"sync/atomic"
2024-01-06 01:50:20 +08:00
)
type TaskPool struct {
workers int
queue chan Task
wg sync.WaitGroup
curTaskID atomic.Uint64
waitMap sync.Map
2024-01-06 01:50:20 +08:00
}
// ErrTaskNotFound is the error used when a task ID has no pending result
// registered in the pool.
type ErrTaskNotFound struct{}

// Error implements the error interface.
func (e *ErrTaskNotFound) Error() string {
	const msg = "task not found"
	return msg
}
// WaitBuf carries the outcome of one task: the value it produced and the
// error it returned, if any.
type WaitBuf struct {
	// Value is the task's result.
	Value any
	// Error is the task's error, or nil.
	Error error
}
2024-01-06 01:50:20 +08:00
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
tp := &TaskPool{
2024-01-06 01:50:20 +08:00
workers: maxWorkers,
queue: make(chan Task, bufferSize),
waitMap: sync.Map{},
curTaskID: atomic.Uint64{},
2024-01-06 01:50:20 +08:00
}
return tp
2024-01-06 01:50:20 +08:00
}
// Start creates tp.workers workers, numbered from 1, and runs each of them
// in its own goroutine. Every worker is given the pool's queue and a
// pointer to the pool's WaitGroup, whose counter is incremented once per
// worker before its goroutine is launched.
//
// NOTE(review): each worker presumably calls Done on the WaitGroup when it
// exits — confirm in Worker.Start.
func (tp *TaskPool) Start() {
	for n := 0; n < tp.workers; n++ {
		w := NewWorker(n+1, tp.queue, tp) // worker IDs are 1-based
		tp.wg.Add(1)
		go w.Start(&tp.wg)
	}
}
func (tp *TaskPool) AddTask(f func() (interface{}, error)) uint64 {
id := tp.curTaskID.Add(1)
2024-01-06 01:50:20 +08:00
waitChan := make(chan WaitBuf, 1)
tp.waitMap.Store(id, waitChan)
2024-01-06 21:03:30 +08:00
task := Task{id: id, f: f}
tp.queue <- task
2024-01-06 01:50:20 +08:00
return id
}
func (tp *TaskPool) WaitForTask(taskID uint64) WaitBuf {
val, ok := tp.waitMap.Load(taskID)
2024-01-06 01:50:20 +08:00
if !ok {
return WaitBuf{nil, &ErrTaskNotFound{}}
2024-01-06 01:50:20 +08:00
}
waitChan := val.(chan WaitBuf)
2024-01-06 01:50:20 +08:00
ret := <-waitChan
close(waitChan)
tp.waitMap.Delete(taskID)
return ret
}
2024-01-06 01:50:20 +08:00
func (tp *TaskPool) markTaskComplete(taskID uint64, buf WaitBuf) {
val, ok := tp.waitMap.Load(taskID)
2024-01-06 01:50:20 +08:00
if !ok {
// should never happen here
panic("worker: task destroyed before completion")
2024-01-06 01:50:20 +08:00
}
waitChan := val.(chan WaitBuf)
waitChan <- buf
2024-01-06 01:50:20 +08:00
}
// Stop closes the task queue and then blocks until the pool's WaitGroup
// counter drops to zero.
//
// Stop must be called at most once: closing an already closed channel
// panics. AddTask must not be called afterwards, since it would send on
// the closed queue.
//
// NOTE(review): workers presumably exit once the closed queue is drained
// and then call Done — confirm in Worker.Start.
func (tp *TaskPool) Stop() {
close(tp.queue) // close first, then wait on the WaitGroup
tp.wg.Wait()
}