feat: capture runtime status from cgroups

pkg/pool: tasks can now return an interface{} result
pkg/pool: use atomics instead of a mutex
service/runner: ContainerRun now returns metrics
Paul Pan 2024-02-15 12:53:57 +08:00
parent 95e861fe43
commit 6956fe4ee1
Signed by: Paul
GPG Key ID: D639BDF5BA578AF4
10 changed files with 232 additions and 143 deletions

View File

@@ -85,7 +85,8 @@ func (s *service) Compile(meta *JudgeMeta) (*JudgeStatus, e.Status) {
 			}
 			id := s.ContainerRunPool(args)
-			return s.pool.WaitForTask(id)
+			ret := s.pool.WaitForTask(id)
+			return ret.Error
 		}).
 		Done()

View File

@@ -22,8 +22,8 @@ var (
 )

 type ConfigRuntime struct {
-	TimeLimit   int `json:"TimeLimit"`
-	MemoryLimit int `json:"MemoryLimit"`
+	TimeLimit   int `json:"TimeLimit"`   // in ms
+	MemoryLimit int `json:"MemoryLimit"` // in mb
 	NProcLimit  int `json:"NProcLimit"`
 }
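
Note: the new unit comments pin down how these fields are interpreted. As a quick illustration of how the tagged fields decode (a standalone sketch; the sample values and the main function are illustrative, not from the repository):

package main

import (
	"encoding/json"
	"fmt"
)

// mirrors the ConfigRuntime struct above; the comments record the units the server assumes
type ConfigRuntime struct {
	TimeLimit   int `json:"TimeLimit"`   // in ms
	MemoryLimit int `json:"MemoryLimit"` // in mb
	NProcLimit  int `json:"NProcLimit"`
}

func main() {
	raw := []byte(`{"TimeLimit": 1000, "MemoryLimit": 256, "NProcLimit": 8}`)
	var cfg ConfigRuntime
	if err := json.Unmarshal(raw, &cfg); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", cfg) // {TimeLimit:1000 MemoryLimit:256 NProcLimit:8}
}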

View File

@@ -2,12 +2,16 @@ package runner
 import (
 	"context"
+	"errors"
 	"fmt"
 	"git.0x7f.app/WOJ/woj-server/pkg/file"
 	"git.0x7f.app/WOJ/woj-server/pkg/utils"
+	cgv1 "github.com/containerd/cgroups/v3/cgroup1/stats"
+	cgv2 "github.com/containerd/cgroups/v3/cgroup2/stats"
 	"github.com/containerd/containerd"
 	"github.com/containerd/containerd/cio"
 	"github.com/containerd/containerd/oci"
+	"github.com/containerd/typeurl/v2"
 	"github.com/opencontainers/runtime-spec/specs-go"
 	"go.uber.org/zap"
 	"io"
@@ -49,7 +53,7 @@ type RunArgs struct {
 	IO      IOArgs
 }

-func (s *service) ContainerRun(arg *RunArgs) error {
+func (s *service) ContainerRun(arg *RunArgs) (RuntimeStatus, error) {
 	identifier := fmt.Sprintf("%d", s.container.count.Add(1))

 	// prepare args
@@ -76,11 +80,13 @@ func (s *service) ContainerRun(arg *RunArgs) error {
 	image, err := s.container.client.GetImage(s.container.ctx, arg.Runtime.Image)
 	// TODO: we could cache the image struct
 	if err != nil {
-		return err
+		return RuntimeStatus{}, err
 	}

 	// create container
+	// TODO: new container is taking too long, could we cache the container struct?
 	container, err := s.container.client.NewContainer(s.container.ctx, "task-"+identifier,
+		// TODO: should we use RO snapshot?
 		containerd.WithNewSnapshot("snapshot-"+identifier, image),
 		containerd.WithNewSpec(
 			oci.WithImageConfig(image),
@@ -92,7 +98,7 @@ func (s *service) ContainerRun(arg *RunArgs) error {
 		),
 	)
 	if err != nil {
-		return err
+		return RuntimeStatus{}, err
 	}
 	defer func(container containerd.Container, ctx context.Context, opts ...containerd.DeleteOpts) {
 		_ = container.Delete(ctx, opts...)
@@ -101,7 +107,7 @@ func (s *service) ContainerRun(arg *RunArgs) error {
 	// create task
 	task, err := container.NewTask(s.container.ctx, cio.NewCreator(cio.WithStreams(nil, writer, writer)))
 	if err != nil {
-		return err
+		return RuntimeStatus{}, err
 	}
 	defer func(task containerd.Task, ctx context.Context, opts ...containerd.ProcessDeleteOpts) {
 		_, _ = task.Delete(ctx, opts...)
@@ -112,13 +118,13 @@ func (s *service) ContainerRun(arg *RunArgs) error {
 	defer cancel()
 	exitStatusC, err := task.Wait(ctx2)
 	if err != nil {
-		return err
+		return RuntimeStatus{}, err
 	}

 	// start
 	err = task.Start(s.container.ctx)
 	if err != nil {
-		return err
+		return RuntimeStatus{}, err
 	}

 	// kill on timeout
@@ -130,13 +136,47 @@ func (s *service) ContainerRun(arg *RunArgs) error {
 		s.log.Debug("container timeout", zap.String("identifier", identifier))
 		err := task.Kill(s.container.ctx, syscall.SIGKILL)
 		if err != nil {
-			return err
+			return RuntimeStatus{}, err
 		}
 	}

-	return nil
+	// get metrics
+	metric, err := task.Metrics(s.container.ctx)
+	if err != nil {
+		return RuntimeStatus{}, err
+	}
+
+	// modified from github.com/containerd/containerd/cmd/ctr/commands/tasks/metrics.go
+	var data interface{}
+	switch {
+	case typeurl.Is(metric.Data, (*cgv1.Metrics)(nil)):
+		data = &cgv1.Metrics{}
+	case typeurl.Is(metric.Data, (*cgv2.Metrics)(nil)):
+		data = &cgv2.Metrics{}
+	default:
+		return RuntimeStatus{}, errors.New("cannot convert metric data to cgroups.Metrics")
+	}
+	if err := typeurl.UnmarshalTo(metric.Data, data); err != nil {
+		return RuntimeStatus{}, err
+	}
+
+	runtime := RuntimeStatus{}
+	switch v := data.(type) {
+	case *cgv1.Metrics:
+		runtime.CpuTime = int(v.CPU.Usage.Total / 1000000) // nanoseconds to milliseconds
+		runtime.Memory = int(v.Memory.Usage.Max / 1024)    // bytes to kilobytes
+		runtime.RealTime = runtime.CpuTime
+	case *cgv2.Metrics:
+		runtime.CpuTime = int(v.CPU.UsageUsec / 1000)  // microseconds to milliseconds
+		runtime.Memory = int(v.Memory.MaxUsage / 1024) // bytes to kilobytes
+		runtime.RealTime = runtime.CpuTime
+	default:
+		return RuntimeStatus{}, errors.New("cannot convert metric data to cgroups.{v1/v2}.Metrics")
+	}
+
+	return runtime, nil
 }

-func (s *service) ContainerRunPool(arg *RunArgs) int {
-	return s.pool.AddTask(func() error { return s.ContainerRun(arg) })
+func (s *service) ContainerRunPool(arg *RunArgs) uint64 {
+	return s.pool.AddTask(func() (interface{}, error) { return s.ContainerRun(arg) })
 }
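
Note: the metrics block above uses the two-step typeurl pattern (adapted, as the comment says, from ctr's metrics command): probe the protobuf Any's type URL to pick the matching cgroup v1 or v2 stats struct, then unmarshal into it. The unit handling is the part worth double-checking: v1 reports CPU time in nanoseconds, v2 in microseconds, and both report memory in bytes. A self-contained sketch of just those conversions (the helper names and sample values are mine, not from the repository):

package main

import "fmt"

// mirrors the RuntimeStatus struct introduced by this commit (ms / ms / kb)
type RuntimeStatus struct {
	RealTime int // in ms
	CpuTime  int // in ms
	Memory   int // in kb
}

// fromCgroupV1 converts cgroup v1 readings: CPU usage arrives in nanoseconds,
// peak memory usage in bytes.
func fromCgroupV1(cpuTotalNs, memMaxBytes uint64) RuntimeStatus {
	cpuMs := int(cpuTotalNs / 1_000_000)
	return RuntimeStatus{RealTime: cpuMs, CpuTime: cpuMs, Memory: int(memMaxBytes / 1024)}
}

// fromCgroupV2 converts cgroup v2 readings: CPU usage arrives in microseconds,
// peak memory usage in bytes.
func fromCgroupV2(cpuUsageUsec, memMaxBytes uint64) RuntimeStatus {
	cpuMs := int(cpuUsageUsec / 1000)
	return RuntimeStatus{RealTime: cpuMs, CpuTime: cpuMs, Memory: int(memMaxBytes / 1024)}
}

func main() {
	fmt.Println(fromCgroupV1(1_500_000_000, 64<<20)) // {1500 1500 65536}
	fmt.Println(fromCgroupV2(1_500_000, 64<<20))     // {1500 1500 65536}
}

Since cgroup accounting has no wall-clock reading, the commit reuses CpuTime as RealTime here; the sandbox's own real_time measurement is folded in later via MergeContainerInfo.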

View File

@@ -80,10 +80,10 @@ func (s *service) PrebuildProblem(meta *JudgeMeta, config *Config, force bool) e
 	}

 	id := s.ContainerRunPool(args)
-	err := s.pool.WaitForTask(id)
-	if err != nil {
-		s.log.Warn("[new] prebuild problem failed", zap.Error(err), zap.Uint("version", meta.Run.Version))
+	ret := s.pool.WaitForTask(id)
+	if ret.Error != nil {
+		s.log.Warn("[new] prebuild problem failed", zap.Any("ret", ret), zap.Uint("version", meta.Run.Version))
 		return e.RunnerProblemPrebuildFailed
 	}

View File

@@ -13,6 +13,13 @@ import (
 	"time"
 )

+type ProblemRunResult struct {
+	QueueId uint64
+	Status  RuntimeStatus
+}
+
+type ProblemRunResults map[int]*ProblemRunResult
+
 func (s *service) SandboxArgsBuilder(meta *JudgeMeta, id int) string {
 	var args []string
@@ -41,7 +48,7 @@ func (s *service) SandboxArgsBuilder(meta *JudgeMeta, id int) string {
 	return strings.Join(args, " ")
 }

-func (s *service) ProblemRun(meta *JudgeMeta) {
+func (s *service) ProblemRun(meta *JudgeMeta) ProblemRunResults {
 	workDir := filepath.Join(UserDir, meta.Run.User)
 	dataDir := filepath.Join(ProblemDir, fmt.Sprintf("%d", meta.Run.Version), "data", "input")
@@ -54,10 +61,10 @@
 		Timeout: time.Duration((meta.Cfg.Lang.Runtime.Run.TimeLimit+1000)/1000+1+1) * time.Second,
 	}

-	ids := make([]int, 0)
+	result := make(ProblemRunResults)
 	for _, task := range meta.Cfg.All.Tasks {
-		f := func(id int) func() error {
-			return func() error {
+		f := func(id int) func() (interface{}, error) {
+			return func() (interface{}, error) {
 				testCase := filepath.Join(dataDir, fmt.Sprintf("%d.input", id))
 				ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id))
 				ifoFile := filepath.Join(workDir, fmt.Sprintf("%d.info", id))
@@ -109,23 +116,45 @@
 					DoAny(func() error { return os.Remove(ifoFile) }).
 					Do(func() error { return file.TouchErr(ansFile) }).
 					Do(func() error { return file.TouchErr(ifoFile) }).
-					Do(func() error { return s.ContainerRun(args) }).
 					Done()
 				if err != nil {
-					s.log.Info("[run] run failed", zap.Error(err), zap.Any("meta", *meta))
+					s.log.Info("[run] prepare failed", zap.Error(err), zap.Any("meta", *meta))
+					return nil, err
 				}

-				return err
+				return s.ContainerRun(args)
 			}
 		}(task.Id)

-		id := s.pool.AddTask(f)
-		ids = append(ids, id)
+		queueId := s.pool.AddTask(f)
+		result[task.Id] = &ProblemRunResult{
+			QueueId: queueId,
+		}
 	}

-	for _, id := range ids {
-		_ = s.pool.WaitForTask(id)
+	for i := range result {
+		waitBuf := s.pool.WaitForTask(result[i].QueueId)
+		if waitBuf.Error != nil {
+			s.log.Error(
+				"[run] wait for problem run failed",
+				zap.Error(waitBuf.Error),
+				zap.Any("meta", *meta))
+			continue
+		}
+
+		val, ok := waitBuf.Value.(RuntimeStatus)
+		if !ok {
+			s.log.Error(
+				"[run] container run is not returning RuntimeStatus",
+				zap.Any("waitBuf", waitBuf),
+				zap.Any("meta", *meta))
+			continue
+		}
+
+		result[i].Status = val
 	}
+
+	return result
 }

 func (s *service) ProblemJudge(meta *JudgeMeta) {
@@ -141,10 +170,10 @@ func (s *service) ProblemJudge(meta *JudgeMeta) {
 		Timeout: time.Duration((meta.Cfg.Lang.Runtime.Check.TimeLimit+1000)/1000) * time.Second,
 	}

-	ids := make([]int, 0)
+	ids := make([]uint64, 0)
 	for _, task := range meta.Cfg.All.Tasks {
-		f := func(id int) func() error {
-			return func() error {
+		f := func(id int) func() (interface{}, error) {
+			return func() (interface{}, error) {
 				ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id))
 				jdgFile := filepath.Join(workDir, fmt.Sprintf("%d.judge", id))
@@ -191,13 +220,14 @@ func (s *service) ProblemJudge(meta *JudgeMeta) {
 				err := utils.NewMust().
 					DoAny(func() error { return os.Remove(jdgFile) }).
 					Do(func() error { return file.TouchErr(jdgFile) }).
-					Do(func() error { return s.ContainerRun(args) }).
 					Done()
 				if err != nil {
-					s.log.Info("[judge] judge failed", zap.Error(err), zap.Any("meta", *meta))
+					s.log.Info("[judge] judge prepare failed", zap.Error(err), zap.Any("meta", *meta))
+					return nil, err
 				}

-				return err
+				return s.ContainerRun(args)
 			}
 		}(task.Id)
@@ -212,13 +242,13 @@
 func (s *service) RunAndJudge(meta *JudgeMeta) (*JudgeStatus, int32, e.Status) {
 	// 1. run user program
-	s.ProblemRun(meta)
+	results := s.ProblemRun(meta)

 	// 2. run judge
 	s.ProblemJudge(meta)

-	// 3. check result
-	result, pts := s.CheckResults(meta)
+	// 3. final JudgeStatus
+	status, pts := s.CheckResults(meta, results)

-	return result, pts, e.Success
+	return status, pts, e.Success
 }
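
Note: ProblemRun now queues every test case before waiting on any of them and collects results keyed by task id, so cases run concurrently up to the pool's worker count. A stripped-down sketch of that fan-out/fan-in shape, with plain goroutines standing in for the task pool (all names here are illustrative):

package main

import "fmt"

type runResult struct {
	id     int
	status int // stand-in for RuntimeStatus
}

func main() {
	tasks := []int{1, 2, 3}
	ch := make(chan runResult, len(tasks))

	// fan out: queue everything before waiting on anything
	for _, id := range tasks {
		go func(id int) { ch <- runResult{id: id, status: id * 100} }(id)
	}

	// fan in: collect results keyed by task id, like ProblemRunResults does
	results := make(map[int]runResult)
	for range tasks {
		r := <-ch
		results[r.id] = r
	}
	fmt.Println(results)
}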

View File

@@ -31,14 +31,18 @@ type TestLibReport struct {
 	Result  string `xml:",chardata"`
 }

+type RuntimeStatus struct {
+	RealTime int `json:"real_time"` // in ms
+	CpuTime  int `json:"cpu_time"`  // in ms
+	Memory   int `json:"memory"`    // in kb
+}
+
 type TaskStatus struct {
 	Id     int   `json:"id"`
 	Points int32 `json:"points"`
-	RealTime int    `json:"real_time"`
-	CpuTime  int    `json:"cpu_time"`
-	Memory   int    `json:"memory"`
-	Verdict  int    `json:"verdict"`
-	Message  string `json:"message"`
+	Runtime RuntimeStatus `json:"runtime"`
+	Verdict int           `json:"verdict"`
+	Message string        `json:"message"`

 	infoText []byte
 	info     map[string]interface{}
@@ -52,7 +56,7 @@
 	Tasks []TaskStatus `json:"tasks"`
 }

-func (t *TaskStatus) getInfoText(infoFile string) *TaskStatus {
+func (t *TaskStatus) ReadSandboxInfo(infoFile string) *TaskStatus {
 	if t.Verdict != VerdictAccepted {
 		return t
 	}
@@ -67,7 +71,7 @@
 	return t
 }

-func (t *TaskStatus) getInfo() *TaskStatus {
+func (t *TaskStatus) ExtractSandboxInfo() *TaskStatus {
 	if t.Verdict != VerdictAccepted {
 		return t
 	}
@@ -77,20 +81,62 @@
 		t.Verdict = VerdictSystemError
 		t.Message = "cannot parse info file"
 	} else {
-		t.RealTime = int(t.info["real_time"].(float64))
-		t.CpuTime = int(t.info["cpu_time"].(float64))
-		t.Memory = int(t.info["memory"].(float64))
+		t.Runtime = RuntimeStatus{
+			RealTime: t.info["real_time"].(int),
+			CpuTime:  t.info["cpu_time"].(int),
+			Memory:   t.info["memory"].(int),
+		}
 	}

 	return t
 }

-func (t *TaskStatus) checkExit() *TaskStatus {
+func (t *TaskStatus) MergeContainerInfo(status *RuntimeStatus) *TaskStatus {
 	if t.Verdict != VerdictAccepted {
 		return t
 	}

-	if t.info["status"] != "exited" || t.info["code"] != 0.0 {
+	t.Runtime.RealTime = max(t.Runtime.RealTime, status.RealTime)
+	t.Runtime.CpuTime = max(t.Runtime.CpuTime, status.CpuTime)
+	t.Runtime.Memory = max(t.Runtime.Memory, status.Memory)
+
+	return t
+}
+
+func (t *TaskStatus) CheckTime(cLang *ConfigLanguage) *TaskStatus {
+	if t.Verdict != VerdictAccepted {
+		return t
+	}
+
+	if t.Runtime.RealTime > cLang.Runtime.Run.TimeLimit+5 ||
+		t.Runtime.CpuTime > cLang.Runtime.Run.TimeLimit+5 {
+		t.Verdict = VerdictTimeLimitExceeded
+		t.Message = fmt.Sprintf("real_time: %v cpu_time: %v", t.Runtime.RealTime, t.Runtime.CpuTime)
+	}
+
+	return t
+}
+
+func (t *TaskStatus) CheckMemory(cLang *ConfigLanguage) *TaskStatus {
+	if t.Verdict != VerdictAccepted {
+		return t
+	}
+
+	// t.Runtime.Memory is in kb
+	if t.Runtime.Memory > (cLang.Runtime.Run.MemoryLimit+1)*1024 {
+		t.Verdict = VerdictMemoryLimitExceeded
+		t.Message = fmt.Sprintf("memory: %v", t.Runtime.Memory)
+	}
+
+	return t
+}
+
+func (t *TaskStatus) CheckExitCode() *TaskStatus {
+	if t.Verdict != VerdictAccepted {
+		return t
+	}
+
+	if t.info["status"] != "exited" || t.info["code"] != 0 {
 		t.Verdict = VerdictRuntimeError
 		t.Message = fmt.Sprintf("status: %v, code: %v", t.info["status"], t.info["code"])
 	}
@@ -98,33 +144,7 @@ func (t *TaskStatus) checkExit() *TaskStatus {
 	return t
 }

-func (t *TaskStatus) checkTime(cLang *ConfigLanguage) *TaskStatus {
-	if t.Verdict != VerdictAccepted {
-		return t
-	}
-
-	if t.info["real_time"].(float64) > float64(cLang.Runtime.Run.TimeLimit)+5 {
-		t.Verdict = VerdictTimeLimitExceeded
-		t.Message = fmt.Sprintf("real_time: %v cpu_time: %v", t.info["real_time"], t.info["cpu_time"])
-	}
-
-	return t
-}
-
-func (t *TaskStatus) checkMemory(cLang *ConfigLanguage) *TaskStatus {
-	if t.Verdict != VerdictAccepted {
-		return t
-	}
-
-	if t.info["memory"].(float64) > float64((cLang.Runtime.Run.MemoryLimit+1)*1024) {
-		t.Verdict = VerdictMemoryLimitExceeded
-		t.Message = fmt.Sprintf("memory: %v", t.info["memory"])
-	}
-
-	return t
-}
-
-func (t *TaskStatus) getJudgeText(judgeFile string) *TaskStatus {
+func (t *TaskStatus) ReadJudgeReport(judgeFile string) *TaskStatus {
 	if t.Verdict != VerdictAccepted {
 		return t
 	}
@@ -132,7 +152,7 @@
 	j, err := file.Read(judgeFile)
 	if err != nil {
 		t.Verdict = VerdictSystemError
-		t.Message = "cannot read judge file"
+		t.Message = "cannot read judge report"
 	} else {
 		t.judgeText = string(j)
 	}
@@ -140,7 +160,7 @@
 	return t
 }

-func (t *TaskStatus) getJudge() *TaskStatus {
+func (t *TaskStatus) DecodeJudgeReport() *TaskStatus {
 	if t.Verdict != VerdictAccepted {
 		return t
 	}
@@ -159,13 +179,13 @@
 	err := d.Decode(&t.judge)
 	if err != nil {
 		t.Verdict = VerdictSystemError
-		t.Message = "cannot parse judge file"
+		t.Message = "cannot parse judge report"
 	}

 	return t
 }

-func (t *TaskStatus) checkJudge(pts *map[int]int32) *TaskStatus {
+func (t *TaskStatus) CheckJudgeReport(pts *map[int]int32) *TaskStatus {
 	if t.Verdict != VerdictAccepted {
 		return t
 	}
@@ -194,7 +214,7 @@
 	return t
 }

-func (s *service) CheckResults(meta *JudgeMeta) (*JudgeStatus, int32) {
+func (s *service) CheckResults(meta *JudgeMeta, prResults ProblemRunResults) (*JudgeStatus, int32) {
 	// CE will be processed in phase compile
 	pts := map[int]int32{}
@@ -212,14 +232,15 @@
 		info := filepath.Join(dir, fmt.Sprintf("%d.info", i))
 		judge := filepath.Join(dir, fmt.Sprintf("%d.judge", i))

-		result.getInfoText(info).
-			getInfo().
-			checkTime(meta.Cfg.Lang).
-			checkMemory(meta.Cfg.Lang).
-			checkExit().
-			getJudgeText(judge).
-			getJudge().
-			checkJudge(&pts)
+		result.ReadSandboxInfo(info).
+			ExtractSandboxInfo().
+			MergeContainerInfo(&prResults[i].Status).
+			CheckTime(meta.Cfg.Lang).
+			CheckMemory(meta.Cfg.Lang).
+			CheckExitCode().
+			ReadJudgeReport(judge).
+			DecodeJudgeReport().
+			CheckJudgeReport(&pts)

 		sum += result.Points
 		results = append(results, result)
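
Note: every method in the renamed chain above opens with the same guard and returns the receiver, so a failed step freezes the verdict and later steps pass through untouched. That short-circuiting is what lets CheckResults compose nine calls without any error plumbing. A minimal sketch of the pattern (types and verdict values are illustrative, not from the repository):

package main

import "fmt"

const accepted = 0

type task struct {
	verdict int
	message string
}

// check mutates the task only while the verdict is still accepted,
// then returns the receiver so calls can chain fluently
func (t *task) check(name string, failed bool, verdict int) *task {
	if t.verdict != accepted {
		return t // short-circuit: an earlier step already decided
	}
	if failed {
		t.verdict = verdict
		t.message = name + " failed"
	}
	return t
}

func main() {
	t := (&task{}).
		check("time", false, 2).
		check("memory", true, 3). // verdict decided here
		check("exit", true, 4)    // skipped by the guard
	fmt.Println(t.verdict, t.message) // 3 memory failed
}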

View File

@@ -2,6 +2,7 @@ package pool
 import (
 	"sync"
+	"sync/atomic"
 )

 type TaskPool struct {
@@ -9,9 +10,8 @@
 	queue   chan Task
 	wg      sync.WaitGroup

-	lck       sync.Mutex
-	curTaskID int
-	waitMap   map[int]chan error
+	curTaskID atomic.Uint64
+	waitMap   sync.Map
 }

 type ErrTaskNotFound struct{}
@@ -20,13 +20,19 @@
 	return "task not found"
 }

+type WaitBuf struct {
+	Value interface{}
+	Error error
+}
+
 func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
-	return &TaskPool{
+	tp := &TaskPool{
 		workers:   maxWorkers,
 		queue:     make(chan Task, bufferSize),
-		waitMap:   make(map[int]chan error),
-		curTaskID: 1, // task id starts from 1
+		waitMap:   sync.Map{},
+		curTaskID: atomic.Uint64{},
 	}
+	return tp
 }

 func (tp *TaskPool) Start() {
@@ -37,16 +43,11 @@
 	}
 }

-func (tp *TaskPool) AddTask(f func() error) int {
-	tp.lck.Lock()
-	id := tp.curTaskID
-	tp.curTaskID++
-
-	waitChan := make(chan error, 1)
-	tp.waitMap[id] = waitChan
-	tp.lck.Unlock()
+func (tp *TaskPool) AddTask(f func() (interface{}, error)) uint64 {
+	id := tp.curTaskID.Add(1)
+
+	waitChan := make(chan WaitBuf, 1)
+	tp.waitMap.Store(id, waitChan)

 	task := Task{id: id, f: f}
 	tp.queue <- task
@@ -54,35 +55,29 @@
 	return id
 }

-func (tp *TaskPool) WaitForTask(taskID int) error {
-	tp.lck.Lock()
-	waitChan, ok := tp.waitMap[taskID]
+func (tp *TaskPool) WaitForTask(taskID uint64) WaitBuf {
+	val, ok := tp.waitMap.Load(taskID)
 	if !ok {
-		tp.lck.Unlock()
-		return &ErrTaskNotFound{}
+		return WaitBuf{nil, &ErrTaskNotFound{}}
 	}
-	tp.lck.Unlock()

+	waitChan := val.(chan WaitBuf)
 	ret := <-waitChan
 	close(waitChan)
-
-	tp.lck.Lock()
-	delete(tp.waitMap, taskID)
-	tp.lck.Unlock()
+	tp.waitMap.Delete(taskID)

 	return ret
 }

-func (tp *TaskPool) markTaskComplete(taskID int, err error) {
-	tp.lck.Lock()
-	waitChan, ok := tp.waitMap[taskID]
+func (tp *TaskPool) markTaskComplete(taskID uint64, buf WaitBuf) {
+	val, ok := tp.waitMap.Load(taskID)
 	if !ok {
 		// should never happen here
 		panic("worker: task destroyed before completion")
 	}
-	tp.lck.Unlock()

-	waitChan <- err
+	waitChan := val.(chan WaitBuf)
+	waitChan <- buf
 }

 func (tp *TaskPool) Stop() {
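
Note: the rewrite replaces the pool's single mutex with two narrower primitives: an atomic counter hands out ids (Add(1) on the zero value yields 1, preserving the old "task id starts from 1" behavior), and a sync.Map holds one buffered channel per task so markTaskComplete can send without blocking. A compact, self-contained sketch of that completion-channel pattern (names are illustrative; goroutines stand in for the worker loop):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type result struct {
	value interface{}
	err   error
}

var (
	nextID  atomic.Uint64
	waiters sync.Map // task id -> chan result
)

func submit(f func() (interface{}, error)) uint64 {
	id := nextID.Add(1)          // first id handed out is 1
	done := make(chan result, 1) // buffered: the completer never blocks
	waiters.Store(id, done)
	go func() {
		v, err := f()
		done <- result{v, err}
	}()
	return id
}

func wait(id uint64) (result, bool) {
	v, ok := waiters.Load(id)
	if !ok {
		return result{}, false // unknown or already-consumed task
	}
	r := <-v.(chan result)
	waiters.Delete(id)
	return r, true
}

func main() {
	id := submit(func() (interface{}, error) { return 42, nil })
	if r, ok := wait(id); ok {
		fmt.Println(r.value) // 42
	}
}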

View File

@@ -2,7 +2,6 @@ package pool
 import (
 	"errors"
-	"strconv"
 	"sync"
 	"testing"
 	"time"
@@ -16,8 +15,8 @@
 	counter := 0

 	for i := 1; i <= 10; i++ {
-		f := func(i int) func() error {
-			return func() error {
+		f := func(i int) func() (interface{}, error) {
+			return func() (interface{}, error) {
 				lck.Lock()
 				t.Log("task", i, "locked")
 				counter += i
@@ -27,7 +26,7 @@
 				time.Sleep(time.Duration(i*100) * time.Millisecond)
 				t.Log("task", i, "finished")
-				return nil
+				return nil, nil
 			}
 		}(i)
 		pool.AddTask(f)
@@ -47,12 +46,12 @@
 	counter := 0

 	for i := 1; i <= 10; i++ {
-		f := func(i int) func() error {
-			return func() error {
+		f := func(i int) func() (interface{}, error) {
+			return func() (interface{}, error) {
 				counter += 1
 				t.Log("task", i, "finished")
 				time.Sleep(100 * time.Millisecond)
-				return errors.New(strconv.Itoa(i))
+				return i, nil
 			}
 		}(i)
 		id := pool.AddTask(f)
@@ -61,8 +60,11 @@
 		if counter != 1 {
 			t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id)
 		}
-		if ret.Error() != strconv.Itoa(i) {
-			t.Errorf("Return value mismatch: expected %s, got %s, task %d", strconv.Itoa(i), ret.Error(), id)
+		if ret.Error != nil {
+			t.Errorf("Return error: %v, task %d", ret.Error, id)
+		}
+		if ret.Value.(int) != i {
+			t.Errorf("Return value mismatch: expected %d, got %v, task %d", i, ret, id)
 		}
 		counter -= 1
 	}
@@ -74,21 +76,21 @@
 	pool := NewTaskPool(1, 1)
 	pool.Start()

-	f := func() error {
+	f := func() (interface{}, error) {
 		t.Log("task invoked")
-		return nil
+		return nil, nil
 	}

 	id := pool.AddTask(f)
 	ret := pool.WaitForTask(id)
-	if ret != nil {
+	if ret.Error != nil {
 		t.Errorf("task returned error: %v", ret)
 	}

 	ret2 := pool.WaitForTask(id)
-	if ret2 == nil {
+	if ret2.Error == nil {
 		t.Errorf("2nd wait returned nil")
-	} else if !errors.Is(ret2, &ErrTaskNotFound{}) {
+	} else if !errors.Is(ret2.Error, &ErrTaskNotFound{}) {
 		t.Errorf("2nd wait returned wrong error: %v", ret2)
 	}
@@ -102,10 +104,10 @@
 	lck := sync.Mutex{}
 	counter := 0

-	ids := make([]int, 0)
+	ids := make([]uint64, 0)
 	for i := 1; i <= 10; i++ {
-		f := func(i int) func() error {
-			return func() error {
+		f := func(i int) func() (interface{}, error) {
+			return func() (interface{}, error) {
 				lck.Lock()
 				t.Log("task", i, "locked")
 				counter += i
@@ -115,7 +117,7 @@
 				time.Sleep(time.Duration(i*10) * time.Millisecond)
 				t.Log("task", i, "finished")
-				return nil
+				return nil, nil
 			}
 		}(i)
 		id := pool.AddTask(f)

View File

@@ -1,6 +1,6 @@
 package pool

 type Task struct {
-	id int
-	f  func() error
+	id uint64
+	f  func() (interface{}, error)
 }

View File

@@ -18,7 +18,7 @@ func (w *Worker) Start(wg *sync.WaitGroup) {
 	defer wg.Done()

 	for task := range w.queue {
-		err := task.f()
-		w.pool.markTaskComplete(task.id, err)
+		val, err := task.f()
+		w.pool.markTaskComplete(task.id, WaitBuf{Value: val, Error: err})
 	}
 }
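
Note: taken together, caller-side usage of the pool after this commit looks roughly like the fragment below (a sketch against the API shown in the diffs above, not a verbatim excerpt from the repository):

pool := NewTaskPool(4, 16) // 4 workers, queue buffer of 16
pool.Start()

id := pool.AddTask(func() (interface{}, error) {
	return RuntimeStatus{RealTime: 12, CpuTime: 12, Memory: 2048}, nil
})

ret := pool.WaitForTask(id) // WaitBuf{Value, Error}
if ret.Error != nil {
	// handle the error
} else if status, ok := ret.Value.(RuntimeStatus); ok {
	_ = status // callers must type-assert, as ProblemRun does
}

pool.Stop()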