From 6956fe4ee1b33d072d5c79f9bbc6b369758e2e53 Mon Sep 17 00:00:00 2001 From: Paul Pan Date: Thu, 15 Feb 2024 12:53:57 +0800 Subject: [PATCH] feat: capture runtime status from cgroups pkg/pool: task is now available to return interface{} as result pkg/pool: use atomic instead of mutex service/runner: ContainerRun will return metrics --- internal/service/runner/compile.go | 3 +- internal/service/runner/config.go | 4 +- internal/service/runner/container.go | 60 ++++++++++-- internal/service/runner/new_problem.go | 6 +- internal/service/runner/run_judge.go | 72 ++++++++++---- internal/service/runner/status.go | 129 ++++++++++++++----------- pkg/pool/pool.go | 55 +++++------ pkg/pool/pool_test.go | 38 ++++---- pkg/pool/task.go | 4 +- pkg/pool/worker.go | 4 +- 10 files changed, 232 insertions(+), 143 deletions(-) diff --git a/internal/service/runner/compile.go b/internal/service/runner/compile.go index d4fca62..a434786 100644 --- a/internal/service/runner/compile.go +++ b/internal/service/runner/compile.go @@ -85,7 +85,8 @@ func (s *service) Compile(meta *JudgeMeta) (*JudgeStatus, e.Status) { } id := s.ContainerRunPool(args) - return s.pool.WaitForTask(id) + ret := s.pool.WaitForTask(id) + return ret.Error }). 
Done() diff --git a/internal/service/runner/config.go b/internal/service/runner/config.go index a6b33fb..b6b9298 100644 --- a/internal/service/runner/config.go +++ b/internal/service/runner/config.go @@ -22,8 +22,8 @@ var ( ) type ConfigRuntime struct { - TimeLimit int `json:"TimeLimit"` - MemoryLimit int `json:"MemoryLimit"` + TimeLimit int `json:"TimeLimit"` // in ms + MemoryLimit int `json:"MemoryLimit"` // in mb NProcLimit int `json:"NProcLimit"` } diff --git a/internal/service/runner/container.go b/internal/service/runner/container.go index 917350d..518d80c 100644 --- a/internal/service/runner/container.go +++ b/internal/service/runner/container.go @@ -2,12 +2,16 @@ package runner import ( "context" + "errors" "fmt" "git.0x7f.app/WOJ/woj-server/pkg/file" "git.0x7f.app/WOJ/woj-server/pkg/utils" + cgv1 "github.com/containerd/cgroups/v3/cgroup1/stats" + cgv2 "github.com/containerd/cgroups/v3/cgroup2/stats" "github.com/containerd/containerd" "github.com/containerd/containerd/cio" "github.com/containerd/containerd/oci" + "github.com/containerd/typeurl/v2" "github.com/opencontainers/runtime-spec/specs-go" "go.uber.org/zap" "io" @@ -49,7 +53,7 @@ type RunArgs struct { IO IOArgs } -func (s *service) ContainerRun(arg *RunArgs) error { +func (s *service) ContainerRun(arg *RunArgs) (RuntimeStatus, error) { identifier := fmt.Sprintf("%d", s.container.count.Add(1)) // prepare args @@ -76,11 +80,13 @@ func (s *service) ContainerRun(arg *RunArgs) error { image, err := s.container.client.GetImage(s.container.ctx, arg.Runtime.Image) // TODO: we could cache the image struct if err != nil { - return err + return RuntimeStatus{}, err } // create container + // TODO: new container is taking too long, could we cache the container struct? container, err := s.container.client.NewContainer(s.container.ctx, "task-"+identifier, + // TODO: should we use RO snapshot? 
containerd.WithNewSnapshot("snapshot-"+identifier, image), containerd.WithNewSpec( oci.WithImageConfig(image), @@ -92,7 +98,7 @@ func (s *service) ContainerRun(arg *RunArgs) error { ), ) if err != nil { - return err + return RuntimeStatus{}, err } defer func(container containerd.Container, ctx context.Context, opts ...containerd.DeleteOpts) { _ = container.Delete(ctx, opts...) @@ -101,7 +107,7 @@ func (s *service) ContainerRun(arg *RunArgs) error { // create task task, err := container.NewTask(s.container.ctx, cio.NewCreator(cio.WithStreams(nil, writer, writer))) if err != nil { - return err + return RuntimeStatus{}, err } defer func(task containerd.Task, ctx context.Context, opts ...containerd.ProcessDeleteOpts) { _, _ = task.Delete(ctx, opts...) @@ -112,13 +118,13 @@ func (s *service) ContainerRun(arg *RunArgs) error { defer cancel() exitStatusC, err := task.Wait(ctx2) if err != nil { - return err + return RuntimeStatus{}, err } // start err = task.Start(s.container.ctx) if err != nil { - return err + return RuntimeStatus{}, err } // kill on timeout @@ -130,13 +136,47 @@ func (s *service) ContainerRun(arg *RunArgs) error { s.log.Debug("container timeout", zap.String("identifier", identifier)) err := task.Kill(s.container.ctx, syscall.SIGKILL) if err != nil { - return err + return RuntimeStatus{}, err } } - return nil + // get metrics + metric, err := task.Metrics(s.container.ctx) + if err != nil { + return RuntimeStatus{}, err + } + + // modified from github.com/containerd/containerd/cmd/ctr/commands/tasks/metrics.go + var data interface{} + switch { + case typeurl.Is(metric.Data, (*cgv1.Metrics)(nil)): + data = &cgv1.Metrics{} + case typeurl.Is(metric.Data, (*cgv2.Metrics)(nil)): + data = &cgv2.Metrics{} + default: + return RuntimeStatus{}, errors.New("cannot convert metric data to cgroups.Metrics") + } + if err := typeurl.UnmarshalTo(metric.Data, data); err != nil { + return RuntimeStatus{}, err + } + + runtime := RuntimeStatus{} + switch v := data.(type) { + 
case *cgv1.Metrics: + runtime.CpuTime = int(v.CPU.Usage.Total / 1000000) // nanoseconds to milliseconds + runtime.Memory = int(v.Memory.Usage.Max / 1024) // bytes to kilobytes + runtime.RealTime = runtime.CpuTime + case *cgv2.Metrics: + runtime.CpuTime = int(v.CPU.UsageUsec / 1000) // microseconds to milliseconds + runtime.Memory = int(v.Memory.MaxUsage / 1024) // bytes to kilobytes + runtime.RealTime = runtime.CpuTime + default: + return RuntimeStatus{}, errors.New("cannot convert metric data to cgroups.{v1/v2}.Metrics") + } + + return runtime, nil } -func (s *service) ContainerRunPool(arg *RunArgs) int { - return s.pool.AddTask(func() error { return s.ContainerRun(arg) }) +func (s *service) ContainerRunPool(arg *RunArgs) uint64 { + return s.pool.AddTask(func() (interface{}, error) { return s.ContainerRun(arg) }) } diff --git a/internal/service/runner/new_problem.go b/internal/service/runner/new_problem.go index b1fe922..cddb298 100644 --- a/internal/service/runner/new_problem.go +++ b/internal/service/runner/new_problem.go @@ -80,10 +80,10 @@ func (s *service) PrebuildProblem(meta *JudgeMeta, config *Config, force bool) e } id := s.ContainerRunPool(args) - err := s.pool.WaitForTask(id) + ret := s.pool.WaitForTask(id) - if err != nil { - s.log.Warn("[new] prebuild problem failed", zap.Error(err), zap.Uint("version", meta.Run.Version)) + if ret.Error != nil { + s.log.Warn("[new] prebuild problem failed", zap.Any("ret", ret), zap.Uint("version", meta.Run.Version)) return e.RunnerProblemPrebuildFailed } diff --git a/internal/service/runner/run_judge.go b/internal/service/runner/run_judge.go index 173a3f2..896813a 100644 --- a/internal/service/runner/run_judge.go +++ b/internal/service/runner/run_judge.go @@ -13,6 +13,13 @@ import ( "time" ) +type ProblemRunResult struct { + QueueId uint64 + Status RuntimeStatus +} + +type ProblemRunResults map[int]*ProblemRunResult + func (s *service) SandboxArgsBuilder(meta *JudgeMeta, id int) string { var args []string @@ -41,7 
+48,7 @@ func (s *service) SandboxArgsBuilder(meta *JudgeMeta, id int) string { return strings.Join(args, " ") } -func (s *service) ProblemRun(meta *JudgeMeta) { +func (s *service) ProblemRun(meta *JudgeMeta) ProblemRunResults { workDir := filepath.Join(UserDir, meta.Run.User) dataDir := filepath.Join(ProblemDir, fmt.Sprintf("%d", meta.Run.Version), "data", "input") @@ -54,10 +61,10 @@ func (s *service) ProblemRun(meta *JudgeMeta) { Timeout: time.Duration((meta.Cfg.Lang.Runtime.Run.TimeLimit+1000)/1000+1+1) * time.Second, } - ids := make([]int, 0) + result := make(ProblemRunResults) for _, task := range meta.Cfg.All.Tasks { - f := func(id int) func() error { - return func() error { + f := func(id int) func() (interface{}, error) { + return func() (interface{}, error) { testCase := filepath.Join(dataDir, fmt.Sprintf("%d.input", id)) ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id)) ifoFile := filepath.Join(workDir, fmt.Sprintf("%d.info", id)) @@ -109,23 +116,45 @@ func (s *service) ProblemRun(meta *JudgeMeta) { DoAny(func() error { return os.Remove(ifoFile) }). Do(func() error { return file.TouchErr(ansFile) }). Do(func() error { return file.TouchErr(ifoFile) }). - Do(func() error { return s.ContainerRun(args) }). 
Done() if err != nil { - s.log.Info("[run] run failed", zap.Error(err), zap.Any("meta", *meta)) + s.log.Info("[run] prepare failed", zap.Error(err), zap.Any("meta", *meta)) + return nil, err } - return err + + return s.ContainerRun(args) } }(task.Id) - id := s.pool.AddTask(f) - ids = append(ids, id) + queueId := s.pool.AddTask(f) + result[task.Id] = &ProblemRunResult{ + QueueId: queueId, + } } - for _, id := range ids { - _ = s.pool.WaitForTask(id) + for i := range result { + waitBuf := s.pool.WaitForTask(result[i].QueueId) + if waitBuf.Error != nil { + s.log.Error( + "[run] wait for problem run failed", + zap.Error(waitBuf.Error), + zap.Any("meta", *meta)) + continue + } + val, ok := waitBuf.Value.(RuntimeStatus) + if !ok { + s.log.Error( + "[run] container run is not returning RuntimeStatus", + zap.Any("waitBuf", waitBuf), + zap.Any("meta", *meta)) + continue + } + + result[i].Status = val } + + return result } func (s *service) ProblemJudge(meta *JudgeMeta) { @@ -141,10 +170,10 @@ func (s *service) ProblemJudge(meta *JudgeMeta) { Timeout: time.Duration((meta.Cfg.Lang.Runtime.Check.TimeLimit+1000)/1000) * time.Second, } - ids := make([]int, 0) + ids := make([]uint64, 0) for _, task := range meta.Cfg.All.Tasks { - f := func(id int) func() error { - return func() error { + f := func(id int) func() (interface{}, error) { + return func() (interface{}, error) { ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id)) jdgFile := filepath.Join(workDir, fmt.Sprintf("%d.judge", id)) @@ -191,13 +220,14 @@ func (s *service) ProblemJudge(meta *JudgeMeta) { err := utils.NewMust(). DoAny(func() error { return os.Remove(jdgFile) }). Do(func() error { return file.TouchErr(jdgFile) }). - Do(func() error { return s.ContainerRun(args) }). 
Done() if err != nil { - s.log.Info("[judge] judge failed", zap.Error(err), zap.Any("meta", *meta)) + s.log.Info("[judge] judge prepare failed", zap.Error(err), zap.Any("meta", *meta)) + return nil, err } - return err + + return s.ContainerRun(args) } }(task.Id) @@ -212,13 +242,13 @@ func (s *service) ProblemJudge(meta *JudgeMeta) { func (s *service) RunAndJudge(meta *JudgeMeta) (*JudgeStatus, int32, e.Status) { // 1. run user program - s.ProblemRun(meta) + results := s.ProblemRun(meta) // 2. run judge s.ProblemJudge(meta) - // 3. check result - result, pts := s.CheckResults(meta) + // 3. final JudgeStatus + status, pts := s.CheckResults(meta, results) - return result, pts, e.Success + return status, pts, e.Success } diff --git a/internal/service/runner/status.go b/internal/service/runner/status.go index 397551f..c921b1e 100644 --- a/internal/service/runner/status.go +++ b/internal/service/runner/status.go @@ -31,14 +31,18 @@ type TestLibReport struct { Result string `xml:",chardata"` } +type RuntimeStatus struct { + RealTime int `json:"real_time"` // in ms + CpuTime int `json:"cpu_time"` // in ms + Memory int `json:"memory"` // in kb +} + type TaskStatus struct { - Id int `json:"id"` - Points int32 `json:"points"` - RealTime int `json:"real_time"` - CpuTime int `json:"cpu_time"` - Memory int `json:"memory"` - Verdict int `json:"verdict"` - Message string `json:"message"` + Id int `json:"id"` + Points int32 `json:"points"` + Runtime RuntimeStatus `json:"runtime"` + Verdict int `json:"verdict"` + Message string `json:"message"` infoText []byte info map[string]interface{} @@ -52,7 +56,7 @@ type JudgeStatus struct { Tasks []TaskStatus `json:"tasks"` } -func (t *TaskStatus) getInfoText(infoFile string) *TaskStatus { +func (t *TaskStatus) ReadSandboxInfo(infoFile string) *TaskStatus { if t.Verdict != VerdictAccepted { return t } @@ -67,7 +71,7 @@ func (t *TaskStatus) getInfoText(infoFile string) *TaskStatus { return t } -func (t *TaskStatus) getInfo() *TaskStatus { 
+func (t *TaskStatus) ExtractSandboxInfo() *TaskStatus { if t.Verdict != VerdictAccepted { return t } @@ -77,20 +81,62 @@ func (t *TaskStatus) getInfo() *TaskStatus { t.Verdict = VerdictSystemError t.Message = "cannot parse info file" } else { - t.RealTime = int(t.info["real_time"].(float64)) - t.CpuTime = int(t.info["cpu_time"].(float64)) - t.Memory = int(t.info["memory"].(float64)) + t.Runtime = RuntimeStatus{ + RealTime: int(t.info["real_time"].(float64)), // JSON numbers decode as float64 + CpuTime: int(t.info["cpu_time"].(float64)), + Memory: int(t.info["memory"].(float64)), + } } return t } -func (t *TaskStatus) checkExit() *TaskStatus { +func (t *TaskStatus) MergeContainerInfo(status *RuntimeStatus) *TaskStatus { if t.Verdict != VerdictAccepted { return t } - if t.info["status"] != "exited" || t.info["code"] != 0.0 { + t.Runtime.RealTime = max(t.Runtime.RealTime, status.RealTime) + t.Runtime.CpuTime = max(t.Runtime.CpuTime, status.CpuTime) + t.Runtime.Memory = max(t.Runtime.Memory, status.Memory) + + return t +} + +func (t *TaskStatus) CheckTime(cLang *ConfigLanguage) *TaskStatus { + if t.Verdict != VerdictAccepted { + return t + } + + if t.Runtime.RealTime > cLang.Runtime.Run.TimeLimit+5 || + t.Runtime.CpuTime > cLang.Runtime.Run.TimeLimit+5 { + t.Verdict = VerdictTimeLimitExceeded + t.Message = fmt.Sprintf("real_time: %v cpu_time: %v", t.Runtime.RealTime, t.Runtime.CpuTime) + } + + return t +} + +func (t *TaskStatus) CheckMemory(cLang *ConfigLanguage) *TaskStatus { + if t.Verdict != VerdictAccepted { + return t + } + + // t.Runtime.Memory is in kb + if t.Runtime.Memory > (cLang.Runtime.Run.MemoryLimit+1)*1024 { + t.Verdict = VerdictMemoryLimitExceeded + t.Message = fmt.Sprintf("memory: %v", t.Runtime.Memory) + } + + return t +} + +func (t *TaskStatus) CheckExitCode() *TaskStatus { + if t.Verdict != VerdictAccepted { + return t + } + + if t.info["status"] != "exited" || t.info["code"] != 0.0 { t.Verdict = VerdictRuntimeError t.Message = fmt.Sprintf("status: %v, code: %v", t.info["status"], t.info["code"]) } @@ -98,33
+144,7 @@ func (t *TaskStatus) checkExit() *TaskStatus { return t } -func (t *TaskStatus) checkTime(cLang *ConfigLanguage) *TaskStatus { - if t.Verdict != VerdictAccepted { - return t - } - - if t.info["real_time"].(float64) > float64(cLang.Runtime.Run.TimeLimit)+5 { - t.Verdict = VerdictTimeLimitExceeded - t.Message = fmt.Sprintf("real_time: %v cpu_time: %v", t.info["real_time"], t.info["cpu_time"]) - } - - return t -} - -func (t *TaskStatus) checkMemory(cLang *ConfigLanguage) *TaskStatus { - if t.Verdict != VerdictAccepted { - return t - } - - if t.info["memory"].(float64) > float64((cLang.Runtime.Run.MemoryLimit+1)*1024) { - t.Verdict = VerdictMemoryLimitExceeded - t.Message = fmt.Sprintf("memory: %v", t.info["memory"]) - } - - return t -} - -func (t *TaskStatus) getJudgeText(judgeFile string) *TaskStatus { +func (t *TaskStatus) ReadJudgeReport(judgeFile string) *TaskStatus { if t.Verdict != VerdictAccepted { return t } @@ -132,7 +152,7 @@ func (t *TaskStatus) getJudgeText(judgeFile string) *TaskStatus { j, err := file.Read(judgeFile) if err != nil { t.Verdict = VerdictSystemError - t.Message = "cannot read judge file" + t.Message = "cannot read judge report" } else { t.judgeText = string(j) } @@ -140,7 +160,7 @@ func (t *TaskStatus) getJudgeText(judgeFile string) *TaskStatus { return t } -func (t *TaskStatus) getJudge() *TaskStatus { +func (t *TaskStatus) DecodeJudgeReport() *TaskStatus { if t.Verdict != VerdictAccepted { return t } @@ -159,13 +179,13 @@ func (t *TaskStatus) getJudge() *TaskStatus { err := d.Decode(&t.judge) if err != nil { t.Verdict = VerdictSystemError - t.Message = "cannot parse judge file" + t.Message = "cannot parse judge report" } return t } -func (t *TaskStatus) checkJudge(pts *map[int]int32) *TaskStatus { +func (t *TaskStatus) CheckJudgeReport(pts *map[int]int32) *TaskStatus { if t.Verdict != VerdictAccepted { return t } @@ -194,7 +214,7 @@ func (t *TaskStatus) checkJudge(pts *map[int]int32) *TaskStatus { return t } -func (s *service) 
CheckResults(meta *JudgeMeta) (*JudgeStatus, int32) { +func (s *service) CheckResults(meta *JudgeMeta, prResults ProblemRunResults) (*JudgeStatus, int32) { // CE will be processed in phase compile pts := map[int]int32{} @@ -212,14 +232,15 @@ func (s *service) CheckResults(meta *JudgeMeta) (*JudgeStatus, int32) { info := filepath.Join(dir, fmt.Sprintf("%d.info", i)) judge := filepath.Join(dir, fmt.Sprintf("%d.judge", i)) - result.getInfoText(info). - getInfo(). - checkTime(meta.Cfg.Lang). - checkMemory(meta.Cfg.Lang). - checkExit(). - getJudgeText(judge). - getJudge(). - checkJudge(&pts) + result.ReadSandboxInfo(info). + ExtractSandboxInfo(). + MergeContainerInfo(&prResults[i].Status). + CheckTime(meta.Cfg.Lang). + CheckMemory(meta.Cfg.Lang). + CheckExitCode(). + ReadJudgeReport(judge). + DecodeJudgeReport(). + CheckJudgeReport(&pts) sum += result.Points results = append(results, result) diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index 4bc1fcd..7b8d758 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -2,6 +2,7 @@ package pool import ( "sync" + "sync/atomic" ) type TaskPool struct { @@ -9,9 +10,8 @@ type TaskPool struct { queue chan Task wg sync.WaitGroup - lck sync.Mutex - curTaskID int - waitMap map[int]chan error + curTaskID atomic.Uint64 + waitMap sync.Map } type ErrTaskNotFound struct{} @@ -20,13 +20,19 @@ func (m *ErrTaskNotFound) Error() string { return "task not found" } +type WaitBuf struct { + Value interface{} + Error error +} + func NewTaskPool(maxWorkers, bufferSize int) *TaskPool { - return &TaskPool{ + tp := &TaskPool{ workers: maxWorkers, queue: make(chan Task, bufferSize), - waitMap: make(map[int]chan error), - curTaskID: 1, // task id starts from 1 + waitMap: sync.Map{}, + curTaskID: atomic.Uint64{}, } + return tp } func (tp *TaskPool) Start() { @@ -37,16 +43,11 @@ func (tp *TaskPool) Start() { } } -func (tp *TaskPool) AddTask(f func() error) int { - tp.lck.Lock() +func (tp *TaskPool) AddTask(f func() (interface{}, error)) uint64 { + id 
:= tp.curTaskID.Add(1) - id := tp.curTaskID - tp.curTaskID++ - - waitChan := make(chan error, 1) - tp.waitMap[id] = waitChan - - tp.lck.Unlock() + waitChan := make(chan WaitBuf, 1) + tp.waitMap.Store(id, waitChan) task := Task{id: id, f: f} tp.queue <- task @@ -54,35 +55,29 @@ func (tp *TaskPool) AddTask(f func() error) int { return id } -func (tp *TaskPool) WaitForTask(taskID int) error { - tp.lck.Lock() - waitChan, ok := tp.waitMap[taskID] +func (tp *TaskPool) WaitForTask(taskID uint64) WaitBuf { + val, ok := tp.waitMap.Load(taskID) if !ok { - tp.lck.Unlock() - return &ErrTaskNotFound{} + return WaitBuf{nil, &ErrTaskNotFound{}} } - tp.lck.Unlock() + waitChan := val.(chan WaitBuf) ret := <-waitChan close(waitChan) - - tp.lck.Lock() - delete(tp.waitMap, taskID) - tp.lck.Unlock() + tp.waitMap.Delete(taskID) return ret } -func (tp *TaskPool) markTaskComplete(taskID int, err error) { - tp.lck.Lock() - waitChan, ok := tp.waitMap[taskID] +func (tp *TaskPool) markTaskComplete(taskID uint64, buf WaitBuf) { + val, ok := tp.waitMap.Load(taskID) if !ok { // should never happen here panic("worker: task destroyed before completion") } - tp.lck.Unlock() - waitChan <- err + waitChan := val.(chan WaitBuf) + waitChan <- buf } func (tp *TaskPool) Stop() { diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go index fdbfa00..c665679 100644 --- a/pkg/pool/pool_test.go +++ b/pkg/pool/pool_test.go @@ -2,7 +2,6 @@ package pool import ( "errors" - "strconv" "sync" "testing" "time" @@ -16,8 +15,8 @@ func TestTaskPool_Stop(t *testing.T) { counter := 0 for i := 1; i <= 10; i++ { - f := func(i int) func() error { - return func() error { + f := func(i int) func() (interface{}, error) { + return func() (interface{}, error) { lck.Lock() t.Log("task", i, "locked") counter += i @@ -27,7 +26,7 @@ func TestTaskPool_Stop(t *testing.T) { time.Sleep(time.Duration(i*100) * time.Millisecond) t.Log("task", i, "finished") - return nil + return nil, nil } }(i) pool.AddTask(f) @@ -47,12 +46,12 @@ func 
TestTaskPool_WaitForTask(t *testing.T) { counter := 0 for i := 1; i <= 10; i++ { - f := func(i int) func() error { - return func() error { + f := func(i int) func() (interface{}, error) { + return func() (interface{}, error) { counter += 1 t.Log("task", i, "finished") time.Sleep(100 * time.Millisecond) - return errors.New(strconv.Itoa(i)) + return i, nil } }(i) id := pool.AddTask(f) @@ -61,8 +60,11 @@ func TestTaskPool_WaitForTask(t *testing.T) { if counter != 1 { t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id) } - if ret.Error() != strconv.Itoa(i) { - t.Errorf("Return value mismatch: expected %s, got %s, task %d", strconv.Itoa(i), ret.Error(), id) + if ret.Error != nil { + t.Errorf("Return error: %v, task %d", ret.Error, id) + } + if ret.Value.(int) != i { + t.Errorf("Return value mismatch: expected %d, got %v, task %d", i, ret, id) } counter -= 1 } @@ -74,21 +76,21 @@ func TestTaskPool_DoubleWait(t *testing.T) { pool := NewTaskPool(1, 1) pool.Start() - f := func() error { + f := func() (interface{}, error) { t.Log("task invoked") - return nil + return nil, nil } id := pool.AddTask(f) ret := pool.WaitForTask(id) - if ret != nil { + if ret.Error != nil { t.Errorf("task returned error: %v", ret) } ret2 := pool.WaitForTask(id) - if ret2 == nil { + if ret2.Error == nil { t.Errorf("2nd wait returned nil") - } else if !errors.Is(ret2, &ErrTaskNotFound{}) { + } else if !errors.Is(ret2.Error, &ErrTaskNotFound{}) { t.Errorf("2nd wait returned wrong error: %v", ret2) } @@ -102,10 +104,10 @@ func TestTaskPool_One(t *testing.T) { lck := sync.Mutex{} counter := 0 - ids := make([]int, 0) + ids := make([]uint64, 0) for i := 1; i <= 10; i++ { - f := func(i int) func() error { - return func() error { + f := func(i int) func() (interface{}, error) { + return func() (interface{}, error) { lck.Lock() t.Log("task", i, "locked") counter += i @@ -115,7 +117,7 @@ func TestTaskPool_One(t *testing.T) { time.Sleep(time.Duration(i*10) * time.Millisecond) 
t.Log("task", i, "finished") - return nil + return nil, nil } }(i) id := pool.AddTask(f) diff --git a/pkg/pool/task.go b/pkg/pool/task.go index 7efa300..c1147f9 100644 --- a/pkg/pool/task.go +++ b/pkg/pool/task.go @@ -1,6 +1,6 @@ package pool type Task struct { - id int - f func() error + id uint64 + f func() (interface{}, error) } diff --git a/pkg/pool/worker.go b/pkg/pool/worker.go index 55e8fef..0b568d7 100644 --- a/pkg/pool/worker.go +++ b/pkg/pool/worker.go @@ -18,7 +18,7 @@ func (w *Worker) Start(wg *sync.WaitGroup) { defer wg.Done() for task := range w.queue { - err := task.f() - w.pool.markTaskComplete(task.id, err) + val, err := task.f() + w.pool.markTaskComplete(task.id, WaitBuf{Value: val, Error: err}) } }