2024-01-06 01:50:20 +08:00
|
|
|
package pool
|
|
|
|
|
|
|
|
import (
|
|
|
|
"sync"
|
2024-02-15 12:53:57 +08:00
|
|
|
"sync/atomic"
|
2024-01-06 01:50:20 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
type TaskPool struct {
|
|
|
|
workers int
|
|
|
|
queue chan Task
|
|
|
|
wg sync.WaitGroup
|
|
|
|
|
2024-02-15 12:53:57 +08:00
|
|
|
curTaskID atomic.Uint64
|
|
|
|
waitMap sync.Map
|
2024-01-06 01:50:20 +08:00
|
|
|
}
|
|
|
|
|
2024-01-28 22:02:26 +08:00
|
|
|
// ErrTaskNotFound is reported in WaitBuf.Error by WaitForTask when the
// requested task ID is unknown or its result has already been collected.
type ErrTaskNotFound struct{}

// Error implements the error interface.
func (e *ErrTaskNotFound) Error() string {
	return "task not found"
}

// Is reports whether target is also a *ErrTaskNotFound, so that
// errors.Is(err, &ErrTaskNotFound{}) matches reliably. Without this method
// errors.Is falls back to pointer equality, and the Go spec does not
// guarantee that pointers to distinct zero-size values compare equal.
func (e *ErrTaskNotFound) Is(target error) bool {
	_, ok := target.(*ErrTaskNotFound)
	return ok
}
|
|
|
|
|
2024-02-15 12:53:57 +08:00
|
|
|
// WaitBuf is the outcome of a task as returned by WaitForTask. Field order
// matters: WaitForTask builds values with a positional literal.
type WaitBuf struct {
	// Value is the task's result value, as passed to markTaskComplete.
	Value any
	// Error is the task's error as passed to markTaskComplete, or a
	// *ErrTaskNotFound when WaitForTask cannot find the task.
	Error error
}
|
|
|
|
|
2024-01-06 01:50:20 +08:00
|
|
|
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
|
2024-02-15 12:53:57 +08:00
|
|
|
tp := &TaskPool{
|
2024-01-06 01:50:20 +08:00
|
|
|
workers: maxWorkers,
|
|
|
|
queue: make(chan Task, bufferSize),
|
2024-02-15 12:53:57 +08:00
|
|
|
waitMap: sync.Map{},
|
|
|
|
curTaskID: atomic.Uint64{},
|
2024-01-06 01:50:20 +08:00
|
|
|
}
|
2024-02-15 12:53:57 +08:00
|
|
|
return tp
|
2024-01-06 01:50:20 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
func (tp *TaskPool) Start() {
|
|
|
|
for i := 1; i <= tp.workers; i++ { // worker id starts from 1
|
|
|
|
worker := NewWorker(i, tp.queue, tp)
|
|
|
|
tp.wg.Add(1)
|
|
|
|
go worker.Start(&tp.wg)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-02-15 12:53:57 +08:00
|
|
|
func (tp *TaskPool) AddTask(f func() (interface{}, error)) uint64 {
|
|
|
|
id := tp.curTaskID.Add(1)
|
2024-01-06 01:50:20 +08:00
|
|
|
|
2024-02-15 12:53:57 +08:00
|
|
|
waitChan := make(chan WaitBuf, 1)
|
|
|
|
tp.waitMap.Store(id, waitChan)
|
2024-01-06 21:03:30 +08:00
|
|
|
|
|
|
|
task := Task{id: id, f: f}
|
|
|
|
tp.queue <- task
|
|
|
|
|
2024-01-06 01:50:20 +08:00
|
|
|
return id
|
|
|
|
}
|
|
|
|
|
2024-02-15 12:53:57 +08:00
|
|
|
func (tp *TaskPool) WaitForTask(taskID uint64) WaitBuf {
|
|
|
|
val, ok := tp.waitMap.Load(taskID)
|
2024-01-06 01:50:20 +08:00
|
|
|
if !ok {
|
2024-02-15 12:53:57 +08:00
|
|
|
return WaitBuf{nil, &ErrTaskNotFound{}}
|
2024-01-06 01:50:20 +08:00
|
|
|
}
|
2024-02-15 12:53:57 +08:00
|
|
|
waitChan := val.(chan WaitBuf)
|
2024-01-06 01:50:20 +08:00
|
|
|
|
2024-01-28 18:16:04 +08:00
|
|
|
ret := <-waitChan
|
|
|
|
close(waitChan)
|
2024-02-15 12:53:57 +08:00
|
|
|
tp.waitMap.Delete(taskID)
|
2024-01-28 18:16:04 +08:00
|
|
|
|
|
|
|
return ret
|
|
|
|
}
|
2024-01-06 01:50:20 +08:00
|
|
|
|
2024-02-15 12:53:57 +08:00
|
|
|
func (tp *TaskPool) markTaskComplete(taskID uint64, buf WaitBuf) {
|
|
|
|
val, ok := tp.waitMap.Load(taskID)
|
2024-01-06 01:50:20 +08:00
|
|
|
if !ok {
|
2024-01-28 22:02:26 +08:00
|
|
|
// should never happen here
|
|
|
|
panic("worker: task destroyed before completion")
|
2024-01-06 01:50:20 +08:00
|
|
|
}
|
|
|
|
|
2024-02-15 12:53:57 +08:00
|
|
|
waitChan := val.(chan WaitBuf)
|
|
|
|
waitChan <- buf
|
2024-01-06 01:50:20 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
func (tp *TaskPool) Stop() {
|
|
|
|
close(tp.queue)
|
|
|
|
tp.wg.Wait()
|
|
|
|
}
|