feat: capture runtime status from cgroups
pkg/pool: task is now available to return interface{} as result pkg/pool: use atomic instead of mutex service/runner: ContainerRun will return metrics
This commit is contained in:
parent
95e861fe43
commit
6956fe4ee1
@ -85,7 +85,8 @@ func (s *service) Compile(meta *JudgeMeta) (*JudgeStatus, e.Status) {
|
||||
}
|
||||
|
||||
id := s.ContainerRunPool(args)
|
||||
return s.pool.WaitForTask(id)
|
||||
ret := s.pool.WaitForTask(id)
|
||||
return ret.Error
|
||||
}).
|
||||
Done()
|
||||
|
||||
|
@ -22,8 +22,8 @@ var (
|
||||
)
|
||||
|
||||
type ConfigRuntime struct {
|
||||
TimeLimit int `json:"TimeLimit"`
|
||||
MemoryLimit int `json:"MemoryLimit"`
|
||||
TimeLimit int `json:"TimeLimit"` // in ms
|
||||
MemoryLimit int `json:"MemoryLimit"` // in mb
|
||||
NProcLimit int `json:"NProcLimit"`
|
||||
}
|
||||
|
||||
|
@ -2,12 +2,16 @@ package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.0x7f.app/WOJ/woj-server/pkg/file"
|
||||
"git.0x7f.app/WOJ/woj-server/pkg/utils"
|
||||
cgv1 "github.com/containerd/cgroups/v3/cgroup1/stats"
|
||||
cgv2 "github.com/containerd/cgroups/v3/cgroup2/stats"
|
||||
"github.com/containerd/containerd"
|
||||
"github.com/containerd/containerd/cio"
|
||||
"github.com/containerd/containerd/oci"
|
||||
"github.com/containerd/typeurl/v2"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
@ -49,7 +53,7 @@ type RunArgs struct {
|
||||
IO IOArgs
|
||||
}
|
||||
|
||||
func (s *service) ContainerRun(arg *RunArgs) error {
|
||||
func (s *service) ContainerRun(arg *RunArgs) (RuntimeStatus, error) {
|
||||
identifier := fmt.Sprintf("%d", s.container.count.Add(1))
|
||||
|
||||
// prepare args
|
||||
@ -76,11 +80,13 @@ func (s *service) ContainerRun(arg *RunArgs) error {
|
||||
image, err := s.container.client.GetImage(s.container.ctx, arg.Runtime.Image)
|
||||
// TODO: we could cache the image struct
|
||||
if err != nil {
|
||||
return err
|
||||
return RuntimeStatus{}, err
|
||||
}
|
||||
|
||||
// create container
|
||||
// TODO: new container is taking too long, could we cache the container struct?
|
||||
container, err := s.container.client.NewContainer(s.container.ctx, "task-"+identifier,
|
||||
// TODO: should we use RO snapshot?
|
||||
containerd.WithNewSnapshot("snapshot-"+identifier, image),
|
||||
containerd.WithNewSpec(
|
||||
oci.WithImageConfig(image),
|
||||
@ -92,7 +98,7 @@ func (s *service) ContainerRun(arg *RunArgs) error {
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return RuntimeStatus{}, err
|
||||
}
|
||||
defer func(container containerd.Container, ctx context.Context, opts ...containerd.DeleteOpts) {
|
||||
_ = container.Delete(ctx, opts...)
|
||||
@ -101,7 +107,7 @@ func (s *service) ContainerRun(arg *RunArgs) error {
|
||||
// create task
|
||||
task, err := container.NewTask(s.container.ctx, cio.NewCreator(cio.WithStreams(nil, writer, writer)))
|
||||
if err != nil {
|
||||
return err
|
||||
return RuntimeStatus{}, err
|
||||
}
|
||||
defer func(task containerd.Task, ctx context.Context, opts ...containerd.ProcessDeleteOpts) {
|
||||
_, _ = task.Delete(ctx, opts...)
|
||||
@ -112,13 +118,13 @@ func (s *service) ContainerRun(arg *RunArgs) error {
|
||||
defer cancel()
|
||||
exitStatusC, err := task.Wait(ctx2)
|
||||
if err != nil {
|
||||
return err
|
||||
return RuntimeStatus{}, err
|
||||
}
|
||||
|
||||
// start
|
||||
err = task.Start(s.container.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return RuntimeStatus{}, err
|
||||
}
|
||||
|
||||
// kill on timeout
|
||||
@ -130,13 +136,47 @@ func (s *service) ContainerRun(arg *RunArgs) error {
|
||||
s.log.Debug("container timeout", zap.String("identifier", identifier))
|
||||
err := task.Kill(s.container.ctx, syscall.SIGKILL)
|
||||
if err != nil {
|
||||
return err
|
||||
return RuntimeStatus{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// get metrics
|
||||
metric, err := task.Metrics(s.container.ctx)
|
||||
if err != nil {
|
||||
return RuntimeStatus{}, err
|
||||
}
|
||||
|
||||
// modified from github.com/containerd/containerd/cmd/ctr/commands/tasks/metrics.go
|
||||
var data interface{}
|
||||
switch {
|
||||
case typeurl.Is(metric.Data, (*cgv1.Metrics)(nil)):
|
||||
data = &cgv1.Metrics{}
|
||||
case typeurl.Is(metric.Data, (*cgv2.Metrics)(nil)):
|
||||
data = &cgv2.Metrics{}
|
||||
default:
|
||||
return RuntimeStatus{}, errors.New("cannot convert metric data to cgroups.Metrics")
|
||||
}
|
||||
if err := typeurl.UnmarshalTo(metric.Data, data); err != nil {
|
||||
return RuntimeStatus{}, err
|
||||
}
|
||||
|
||||
runtime := RuntimeStatus{}
|
||||
switch v := data.(type) {
|
||||
case *cgv1.Metrics:
|
||||
runtime.CpuTime = int(v.CPU.Usage.Total / 1000000) // nanoseconds to milliseconds
|
||||
runtime.Memory = int(v.Memory.Usage.Max / 1024) // bytes to kilobytes
|
||||
runtime.RealTime = runtime.CpuTime
|
||||
case *cgv2.Metrics:
|
||||
runtime.CpuTime = int(v.CPU.UsageUsec / 1000) // microseconds to milliseconds
|
||||
runtime.Memory = int(v.Memory.MaxUsage / 1024) // bytes to kilobytes
|
||||
runtime.RealTime = runtime.CpuTime
|
||||
default:
|
||||
return RuntimeStatus{}, errors.New("cannot convert metric data to cgroups.{v1/v2}.Metrics")
|
||||
}
|
||||
|
||||
return runtime, nil
|
||||
}
|
||||
|
||||
func (s *service) ContainerRunPool(arg *RunArgs) int {
|
||||
return s.pool.AddTask(func() error { return s.ContainerRun(arg) })
|
||||
func (s *service) ContainerRunPool(arg *RunArgs) uint64 {
|
||||
return s.pool.AddTask(func() (interface{}, error) { return s.ContainerRun(arg) })
|
||||
}
|
||||
|
@ -80,10 +80,10 @@ func (s *service) PrebuildProblem(meta *JudgeMeta, config *Config, force bool) e
|
||||
}
|
||||
|
||||
id := s.ContainerRunPool(args)
|
||||
err := s.pool.WaitForTask(id)
|
||||
ret := s.pool.WaitForTask(id)
|
||||
|
||||
if err != nil {
|
||||
s.log.Warn("[new] prebuild problem failed", zap.Error(err), zap.Uint("version", meta.Run.Version))
|
||||
if ret.Error != nil {
|
||||
s.log.Warn("[new] prebuild problem failed", zap.Any("ret", ret), zap.Uint("version", meta.Run.Version))
|
||||
return e.RunnerProblemPrebuildFailed
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProblemRunResult struct {
|
||||
QueueId uint64
|
||||
Status RuntimeStatus
|
||||
}
|
||||
|
||||
type ProblemRunResults map[int]*ProblemRunResult
|
||||
|
||||
func (s *service) SandboxArgsBuilder(meta *JudgeMeta, id int) string {
|
||||
var args []string
|
||||
|
||||
@ -41,7 +48,7 @@ func (s *service) SandboxArgsBuilder(meta *JudgeMeta, id int) string {
|
||||
return strings.Join(args, " ")
|
||||
}
|
||||
|
||||
func (s *service) ProblemRun(meta *JudgeMeta) {
|
||||
func (s *service) ProblemRun(meta *JudgeMeta) ProblemRunResults {
|
||||
workDir := filepath.Join(UserDir, meta.Run.User)
|
||||
dataDir := filepath.Join(ProblemDir, fmt.Sprintf("%d", meta.Run.Version), "data", "input")
|
||||
|
||||
@ -54,10 +61,10 @@ func (s *service) ProblemRun(meta *JudgeMeta) {
|
||||
Timeout: time.Duration((meta.Cfg.Lang.Runtime.Run.TimeLimit+1000)/1000+1+1) * time.Second,
|
||||
}
|
||||
|
||||
ids := make([]int, 0)
|
||||
result := make(ProblemRunResults)
|
||||
for _, task := range meta.Cfg.All.Tasks {
|
||||
f := func(id int) func() error {
|
||||
return func() error {
|
||||
f := func(id int) func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
testCase := filepath.Join(dataDir, fmt.Sprintf("%d.input", id))
|
||||
ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id))
|
||||
ifoFile := filepath.Join(workDir, fmt.Sprintf("%d.info", id))
|
||||
@ -109,23 +116,45 @@ func (s *service) ProblemRun(meta *JudgeMeta) {
|
||||
DoAny(func() error { return os.Remove(ifoFile) }).
|
||||
Do(func() error { return file.TouchErr(ansFile) }).
|
||||
Do(func() error { return file.TouchErr(ifoFile) }).
|
||||
Do(func() error { return s.ContainerRun(args) }).
|
||||
Done()
|
||||
|
||||
if err != nil {
|
||||
s.log.Info("[run] run failed", zap.Error(err), zap.Any("meta", *meta))
|
||||
s.log.Info("[run] prepare failed", zap.Error(err), zap.Any("meta", *meta))
|
||||
return nil, err
|
||||
}
|
||||
return err
|
||||
|
||||
return s.ContainerRun(args)
|
||||
}
|
||||
}(task.Id)
|
||||
|
||||
id := s.pool.AddTask(f)
|
||||
ids = append(ids, id)
|
||||
queueId := s.pool.AddTask(f)
|
||||
result[task.Id] = &ProblemRunResult{
|
||||
QueueId: queueId,
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
_ = s.pool.WaitForTask(id)
|
||||
for i := range result {
|
||||
waitBuf := s.pool.WaitForTask(result[i].QueueId)
|
||||
if waitBuf.Error != nil {
|
||||
s.log.Error(
|
||||
"[run] wait for problem run failed",
|
||||
zap.Error(waitBuf.Error),
|
||||
zap.Any("meta", *meta))
|
||||
continue
|
||||
}
|
||||
val, ok := waitBuf.Value.(RuntimeStatus)
|
||||
if !ok {
|
||||
s.log.Error(
|
||||
"[run] container run is not returning RuntimeStatus",
|
||||
zap.Any("waitBuf", waitBuf),
|
||||
zap.Any("meta", *meta))
|
||||
continue
|
||||
}
|
||||
|
||||
result[i].Status = val
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *service) ProblemJudge(meta *JudgeMeta) {
|
||||
@ -141,10 +170,10 @@ func (s *service) ProblemJudge(meta *JudgeMeta) {
|
||||
Timeout: time.Duration((meta.Cfg.Lang.Runtime.Check.TimeLimit+1000)/1000) * time.Second,
|
||||
}
|
||||
|
||||
ids := make([]int, 0)
|
||||
ids := make([]uint64, 0)
|
||||
for _, task := range meta.Cfg.All.Tasks {
|
||||
f := func(id int) func() error {
|
||||
return func() error {
|
||||
f := func(id int) func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id))
|
||||
jdgFile := filepath.Join(workDir, fmt.Sprintf("%d.judge", id))
|
||||
|
||||
@ -191,13 +220,14 @@ func (s *service) ProblemJudge(meta *JudgeMeta) {
|
||||
err := utils.NewMust().
|
||||
DoAny(func() error { return os.Remove(jdgFile) }).
|
||||
Do(func() error { return file.TouchErr(jdgFile) }).
|
||||
Do(func() error { return s.ContainerRun(args) }).
|
||||
Done()
|
||||
|
||||
if err != nil {
|
||||
s.log.Info("[judge] judge failed", zap.Error(err), zap.Any("meta", *meta))
|
||||
s.log.Info("[judge] judge prepare failed", zap.Error(err), zap.Any("meta", *meta))
|
||||
return nil, err
|
||||
}
|
||||
return err
|
||||
|
||||
return s.ContainerRun(args)
|
||||
}
|
||||
}(task.Id)
|
||||
|
||||
@ -212,13 +242,13 @@ func (s *service) ProblemJudge(meta *JudgeMeta) {
|
||||
|
||||
func (s *service) RunAndJudge(meta *JudgeMeta) (*JudgeStatus, int32, e.Status) {
|
||||
// 1. run user program
|
||||
s.ProblemRun(meta)
|
||||
results := s.ProblemRun(meta)
|
||||
|
||||
// 2. run judge
|
||||
s.ProblemJudge(meta)
|
||||
|
||||
// 3. check result
|
||||
result, pts := s.CheckResults(meta)
|
||||
// 3. final JudgeStatus
|
||||
status, pts := s.CheckResults(meta, results)
|
||||
|
||||
return result, pts, e.Success
|
||||
return status, pts, e.Success
|
||||
}
|
||||
|
@ -31,14 +31,18 @@ type TestLibReport struct {
|
||||
Result string `xml:",chardata"`
|
||||
}
|
||||
|
||||
type RuntimeStatus struct {
|
||||
RealTime int `json:"real_time"` // in ms
|
||||
CpuTime int `json:"cpu_time"` // in ms
|
||||
Memory int `json:"memory"` // in kb
|
||||
}
|
||||
|
||||
type TaskStatus struct {
|
||||
Id int `json:"id"`
|
||||
Points int32 `json:"points"`
|
||||
RealTime int `json:"real_time"`
|
||||
CpuTime int `json:"cpu_time"`
|
||||
Memory int `json:"memory"`
|
||||
Verdict int `json:"verdict"`
|
||||
Message string `json:"message"`
|
||||
Id int `json:"id"`
|
||||
Points int32 `json:"points"`
|
||||
Runtime RuntimeStatus `json:"runtime"`
|
||||
Verdict int `json:"verdict"`
|
||||
Message string `json:"message"`
|
||||
|
||||
infoText []byte
|
||||
info map[string]interface{}
|
||||
@ -52,7 +56,7 @@ type JudgeStatus struct {
|
||||
Tasks []TaskStatus `json:"tasks"`
|
||||
}
|
||||
|
||||
func (t *TaskStatus) getInfoText(infoFile string) *TaskStatus {
|
||||
func (t *TaskStatus) ReadSandboxInfo(infoFile string) *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
@ -67,7 +71,7 @@ func (t *TaskStatus) getInfoText(infoFile string) *TaskStatus {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) getInfo() *TaskStatus {
|
||||
func (t *TaskStatus) ExtractSandboxInfo() *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
@ -77,20 +81,62 @@ func (t *TaskStatus) getInfo() *TaskStatus {
|
||||
t.Verdict = VerdictSystemError
|
||||
t.Message = "cannot parse info file"
|
||||
} else {
|
||||
t.RealTime = int(t.info["real_time"].(float64))
|
||||
t.CpuTime = int(t.info["cpu_time"].(float64))
|
||||
t.Memory = int(t.info["memory"].(float64))
|
||||
t.Runtime = RuntimeStatus{
|
||||
RealTime: t.info["real_time"].(int),
|
||||
CpuTime: t.info["cpu_time"].(int),
|
||||
Memory: t.info["memory"].(int),
|
||||
}
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) checkExit() *TaskStatus {
|
||||
func (t *TaskStatus) MergeContainerInfo(status *RuntimeStatus) *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
|
||||
if t.info["status"] != "exited" || t.info["code"] != 0.0 {
|
||||
t.Runtime.RealTime = max(t.Runtime.RealTime, status.RealTime)
|
||||
t.Runtime.CpuTime = max(t.Runtime.CpuTime, status.CpuTime)
|
||||
t.Runtime.Memory = max(t.Runtime.Memory, status.Memory)
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) CheckTime(cLang *ConfigLanguage) *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
|
||||
if t.Runtime.RealTime > cLang.Runtime.Run.TimeLimit+5 ||
|
||||
t.Runtime.CpuTime > cLang.Runtime.Run.TimeLimit+5 {
|
||||
t.Verdict = VerdictTimeLimitExceeded
|
||||
t.Message = fmt.Sprintf("real_time: %v cpu_time: %v", t.Runtime.RealTime, t.Runtime.CpuTime)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) CheckMemory(cLang *ConfigLanguage) *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
|
||||
// t.Runtime.Memory is in kb
|
||||
if t.Runtime.Memory > (cLang.Runtime.Run.MemoryLimit+1)*1024 {
|
||||
t.Verdict = VerdictMemoryLimitExceeded
|
||||
t.Message = fmt.Sprintf("memory: %v", t.Runtime.Memory)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) CheckExitCode() *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
|
||||
if t.info["status"] != "exited" || t.info["code"] != 0 {
|
||||
t.Verdict = VerdictRuntimeError
|
||||
t.Message = fmt.Sprintf("status: %v, code: %v", t.info["status"], t.info["code"])
|
||||
}
|
||||
@ -98,33 +144,7 @@ func (t *TaskStatus) checkExit() *TaskStatus {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) checkTime(cLang *ConfigLanguage) *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
|
||||
if t.info["real_time"].(float64) > float64(cLang.Runtime.Run.TimeLimit)+5 {
|
||||
t.Verdict = VerdictTimeLimitExceeded
|
||||
t.Message = fmt.Sprintf("real_time: %v cpu_time: %v", t.info["real_time"], t.info["cpu_time"])
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) checkMemory(cLang *ConfigLanguage) *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
|
||||
if t.info["memory"].(float64) > float64((cLang.Runtime.Run.MemoryLimit+1)*1024) {
|
||||
t.Verdict = VerdictMemoryLimitExceeded
|
||||
t.Message = fmt.Sprintf("memory: %v", t.info["memory"])
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) getJudgeText(judgeFile string) *TaskStatus {
|
||||
func (t *TaskStatus) ReadJudgeReport(judgeFile string) *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
@ -132,7 +152,7 @@ func (t *TaskStatus) getJudgeText(judgeFile string) *TaskStatus {
|
||||
j, err := file.Read(judgeFile)
|
||||
if err != nil {
|
||||
t.Verdict = VerdictSystemError
|
||||
t.Message = "cannot read judge file"
|
||||
t.Message = "cannot read judge report"
|
||||
} else {
|
||||
t.judgeText = string(j)
|
||||
}
|
||||
@ -140,7 +160,7 @@ func (t *TaskStatus) getJudgeText(judgeFile string) *TaskStatus {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) getJudge() *TaskStatus {
|
||||
func (t *TaskStatus) DecodeJudgeReport() *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
@ -159,13 +179,13 @@ func (t *TaskStatus) getJudge() *TaskStatus {
|
||||
err := d.Decode(&t.judge)
|
||||
if err != nil {
|
||||
t.Verdict = VerdictSystemError
|
||||
t.Message = "cannot parse judge file"
|
||||
t.Message = "cannot parse judge report"
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TaskStatus) checkJudge(pts *map[int]int32) *TaskStatus {
|
||||
func (t *TaskStatus) CheckJudgeReport(pts *map[int]int32) *TaskStatus {
|
||||
if t.Verdict != VerdictAccepted {
|
||||
return t
|
||||
}
|
||||
@ -194,7 +214,7 @@ func (t *TaskStatus) checkJudge(pts *map[int]int32) *TaskStatus {
|
||||
return t
|
||||
}
|
||||
|
||||
func (s *service) CheckResults(meta *JudgeMeta) (*JudgeStatus, int32) {
|
||||
func (s *service) CheckResults(meta *JudgeMeta, prResults ProblemRunResults) (*JudgeStatus, int32) {
|
||||
// CE will be processed in phase compile
|
||||
|
||||
pts := map[int]int32{}
|
||||
@ -212,14 +232,15 @@ func (s *service) CheckResults(meta *JudgeMeta) (*JudgeStatus, int32) {
|
||||
info := filepath.Join(dir, fmt.Sprintf("%d.info", i))
|
||||
judge := filepath.Join(dir, fmt.Sprintf("%d.judge", i))
|
||||
|
||||
result.getInfoText(info).
|
||||
getInfo().
|
||||
checkTime(meta.Cfg.Lang).
|
||||
checkMemory(meta.Cfg.Lang).
|
||||
checkExit().
|
||||
getJudgeText(judge).
|
||||
getJudge().
|
||||
checkJudge(&pts)
|
||||
result.ReadSandboxInfo(info).
|
||||
ExtractSandboxInfo().
|
||||
MergeContainerInfo(&prResults[i].Status).
|
||||
CheckTime(meta.Cfg.Lang).
|
||||
CheckMemory(meta.Cfg.Lang).
|
||||
CheckExitCode().
|
||||
ReadJudgeReport(judge).
|
||||
DecodeJudgeReport().
|
||||
CheckJudgeReport(&pts)
|
||||
|
||||
sum += result.Points
|
||||
results = append(results, result)
|
||||
|
@ -2,6 +2,7 @@ package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type TaskPool struct {
|
||||
@ -9,9 +10,8 @@ type TaskPool struct {
|
||||
queue chan Task
|
||||
wg sync.WaitGroup
|
||||
|
||||
lck sync.Mutex
|
||||
curTaskID int
|
||||
waitMap map[int]chan error
|
||||
curTaskID atomic.Uint64
|
||||
waitMap sync.Map
|
||||
}
|
||||
|
||||
type ErrTaskNotFound struct{}
|
||||
@ -20,13 +20,19 @@ func (m *ErrTaskNotFound) Error() string {
|
||||
return "task not found"
|
||||
}
|
||||
|
||||
type WaitBuf struct {
|
||||
Value interface{}
|
||||
Error error
|
||||
}
|
||||
|
||||
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
|
||||
return &TaskPool{
|
||||
tp := &TaskPool{
|
||||
workers: maxWorkers,
|
||||
queue: make(chan Task, bufferSize),
|
||||
waitMap: make(map[int]chan error),
|
||||
curTaskID: 1, // task id starts from 1
|
||||
waitMap: sync.Map{},
|
||||
curTaskID: atomic.Uint64{},
|
||||
}
|
||||
return tp
|
||||
}
|
||||
|
||||
func (tp *TaskPool) Start() {
|
||||
@ -37,16 +43,11 @@ func (tp *TaskPool) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *TaskPool) AddTask(f func() error) int {
|
||||
tp.lck.Lock()
|
||||
func (tp *TaskPool) AddTask(f func() (interface{}, error)) uint64 {
|
||||
id := tp.curTaskID.Add(1)
|
||||
|
||||
id := tp.curTaskID
|
||||
tp.curTaskID++
|
||||
|
||||
waitChan := make(chan error, 1)
|
||||
tp.waitMap[id] = waitChan
|
||||
|
||||
tp.lck.Unlock()
|
||||
waitChan := make(chan WaitBuf, 1)
|
||||
tp.waitMap.Store(id, waitChan)
|
||||
|
||||
task := Task{id: id, f: f}
|
||||
tp.queue <- task
|
||||
@ -54,35 +55,29 @@ func (tp *TaskPool) AddTask(f func() error) int {
|
||||
return id
|
||||
}
|
||||
|
||||
func (tp *TaskPool) WaitForTask(taskID int) error {
|
||||
tp.lck.Lock()
|
||||
waitChan, ok := tp.waitMap[taskID]
|
||||
func (tp *TaskPool) WaitForTask(taskID uint64) WaitBuf {
|
||||
val, ok := tp.waitMap.Load(taskID)
|
||||
if !ok {
|
||||
tp.lck.Unlock()
|
||||
return &ErrTaskNotFound{}
|
||||
return WaitBuf{nil, &ErrTaskNotFound{}}
|
||||
}
|
||||
tp.lck.Unlock()
|
||||
waitChan := val.(chan WaitBuf)
|
||||
|
||||
ret := <-waitChan
|
||||
close(waitChan)
|
||||
|
||||
tp.lck.Lock()
|
||||
delete(tp.waitMap, taskID)
|
||||
tp.lck.Unlock()
|
||||
tp.waitMap.Delete(taskID)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (tp *TaskPool) markTaskComplete(taskID int, err error) {
|
||||
tp.lck.Lock()
|
||||
waitChan, ok := tp.waitMap[taskID]
|
||||
func (tp *TaskPool) markTaskComplete(taskID uint64, buf WaitBuf) {
|
||||
val, ok := tp.waitMap.Load(taskID)
|
||||
if !ok {
|
||||
// should never happen here
|
||||
panic("worker: task destroyed before completion")
|
||||
}
|
||||
tp.lck.Unlock()
|
||||
|
||||
waitChan <- err
|
||||
waitChan := val.(chan WaitBuf)
|
||||
waitChan <- buf
|
||||
}
|
||||
|
||||
func (tp *TaskPool) Stop() {
|
||||
|
@ -2,7 +2,6 @@ package pool
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -16,8 +15,8 @@ func TestTaskPool_Stop(t *testing.T) {
|
||||
counter := 0
|
||||
|
||||
for i := 1; i <= 10; i++ {
|
||||
f := func(i int) func() error {
|
||||
return func() error {
|
||||
f := func(i int) func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
lck.Lock()
|
||||
t.Log("task", i, "locked")
|
||||
counter += i
|
||||
@ -27,7 +26,7 @@ func TestTaskPool_Stop(t *testing.T) {
|
||||
time.Sleep(time.Duration(i*100) * time.Millisecond)
|
||||
t.Log("task", i, "finished")
|
||||
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}(i)
|
||||
pool.AddTask(f)
|
||||
@ -47,12 +46,12 @@ func TestTaskPool_WaitForTask(t *testing.T) {
|
||||
counter := 0
|
||||
|
||||
for i := 1; i <= 10; i++ {
|
||||
f := func(i int) func() error {
|
||||
return func() error {
|
||||
f := func(i int) func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
counter += 1
|
||||
t.Log("task", i, "finished")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return errors.New(strconv.Itoa(i))
|
||||
return i, nil
|
||||
}
|
||||
}(i)
|
||||
id := pool.AddTask(f)
|
||||
@ -61,8 +60,11 @@ func TestTaskPool_WaitForTask(t *testing.T) {
|
||||
if counter != 1 {
|
||||
t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id)
|
||||
}
|
||||
if ret.Error() != strconv.Itoa(i) {
|
||||
t.Errorf("Return value mismatch: expected %s, got %s, task %d", strconv.Itoa(i), ret.Error(), id)
|
||||
if ret.Error != nil {
|
||||
t.Errorf("Return error: %v, task %d", ret.Error, id)
|
||||
}
|
||||
if ret.Value.(int) != i {
|
||||
t.Errorf("Return value mismatch: expected %d, got %v, task %d", i, ret, id)
|
||||
}
|
||||
counter -= 1
|
||||
}
|
||||
@ -74,21 +76,21 @@ func TestTaskPool_DoubleWait(t *testing.T) {
|
||||
pool := NewTaskPool(1, 1)
|
||||
pool.Start()
|
||||
|
||||
f := func() error {
|
||||
f := func() (interface{}, error) {
|
||||
t.Log("task invoked")
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
id := pool.AddTask(f)
|
||||
|
||||
ret := pool.WaitForTask(id)
|
||||
if ret != nil {
|
||||
if ret.Error != nil {
|
||||
t.Errorf("task returned error: %v", ret)
|
||||
}
|
||||
|
||||
ret2 := pool.WaitForTask(id)
|
||||
if ret2 == nil {
|
||||
if ret2.Error == nil {
|
||||
t.Errorf("2nd wait returned nil")
|
||||
} else if !errors.Is(ret2, &ErrTaskNotFound{}) {
|
||||
} else if !errors.Is(ret2.Error, &ErrTaskNotFound{}) {
|
||||
t.Errorf("2nd wait returned wrong error: %v", ret2)
|
||||
}
|
||||
|
||||
@ -102,10 +104,10 @@ func TestTaskPool_One(t *testing.T) {
|
||||
lck := sync.Mutex{}
|
||||
counter := 0
|
||||
|
||||
ids := make([]int, 0)
|
||||
ids := make([]uint64, 0)
|
||||
for i := 1; i <= 10; i++ {
|
||||
f := func(i int) func() error {
|
||||
return func() error {
|
||||
f := func(i int) func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
lck.Lock()
|
||||
t.Log("task", i, "locked")
|
||||
counter += i
|
||||
@ -115,7 +117,7 @@ func TestTaskPool_One(t *testing.T) {
|
||||
time.Sleep(time.Duration(i*10) * time.Millisecond)
|
||||
t.Log("task", i, "finished")
|
||||
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}(i)
|
||||
id := pool.AddTask(f)
|
||||
|
@ -1,6 +1,6 @@
|
||||
package pool
|
||||
|
||||
type Task struct {
|
||||
id int
|
||||
f func() error
|
||||
id uint64
|
||||
f func() (interface{}, error)
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ func (w *Worker) Start(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
for task := range w.queue {
|
||||
err := task.f()
|
||||
w.pool.markTaskComplete(task.id, err)
|
||||
val, err := task.f()
|
||||
w.pool.markTaskComplete(task.id, WaitBuf{Value: val, Error: err})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user