From 6965435cb786e4a65c6c0c93a01d6806992860b4 Mon Sep 17 00:00:00 2001 From: Paul Pan Date: Sun, 28 Jan 2024 18:16:04 +0800 Subject: [PATCH] feat: `compile` and `new_problem` are queued by pool (#7) --- internal/service/runner/compile.go | 3 +- internal/service/runner/container.go | 4 +++ internal/service/runner/new_problem.go | 3 +- internal/service/runner/run_judge.go | 14 ++++---- pkg/pool/pool.go | 45 +++++++++++++++----------- pkg/pool/pool_test.go | 26 ++++++++++----- pkg/pool/task.go | 2 +- pkg/pool/worker.go | 4 +-- 8 files changed, 63 insertions(+), 38 deletions(-) diff --git a/internal/service/runner/compile.go b/internal/service/runner/compile.go index 5035dbe..cf9662d 100644 --- a/internal/service/runner/compile.go +++ b/internal/service/runner/compile.go @@ -97,7 +97,8 @@ func (s *service) Compile(meta *JudgeMeta) (*JudgeStatus, e.Status) { }, } - return s.ContainerRun(args) + id := s.ContainerRunPool(args) + return s.pool.WaitForTask(id) }). Done() diff --git a/internal/service/runner/container.go b/internal/service/runner/container.go index 3be2367..917350d 100644 --- a/internal/service/runner/container.go +++ b/internal/service/runner/container.go @@ -136,3 +136,7 @@ func (s *service) ContainerRun(arg *RunArgs) error { return nil } + +func (s *service) ContainerRunPool(arg *RunArgs) int { + return s.pool.AddTask(func() error { return s.ContainerRun(arg) }) +} diff --git a/internal/service/runner/new_problem.go b/internal/service/runner/new_problem.go index 94fbb93..47e9d1b 100644 --- a/internal/service/runner/new_problem.go +++ b/internal/service/runner/new_problem.go @@ -79,7 +79,8 @@ func (s *service) PrebuildProblem(meta *JudgeMeta, config *Config, force bool) e }, } - err := s.ContainerRun(args) + id := s.ContainerRunPool(args) + err := s.pool.WaitForTask(id) if err != nil { s.log.Warn("[new] prebuild problem failed", zap.Error(err), zap.Uint("version", meta.Version)) diff --git a/internal/service/runner/run_judge.go b/internal/service/runner/run_judge.go index 761f4f1..ea826aa 100644 --- a/internal/service/runner/run_judge.go +++ b/internal/service/runner/run_judge.go @@ -46,8 +46,8 @@ func (s *service) ProblemRun(meta *JudgeMeta, config *Config, cLang *ConfigLangu ids := make([]int, 0) for _, task := range config.Tasks { - f := func(id int) func() { - return func() { + f := func(id int) func() error { + return func() error { testCase := filepath.Join(dataDir, fmt.Sprintf("%d.input", id)) targetFile := filepath.Join(workDir, fmt.Sprintf("%s.out", meta.User)) ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id)) @@ -101,6 +101,7 @@ func (s *service) ProblemRun(meta *JudgeMeta, config *Config, cLang *ConfigLangu if err != nil { s.log.Info("[run] run failed", zap.Error(err), zap.Any("meta", *meta)) } + return err } }(task.Id) @@ -109,7 +110,7 @@ func (s *service) ProblemRun(meta *JudgeMeta, config *Config, cLang *ConfigLangu } for _, id := range ids { - s.pool.WaitForTask(id) + _ = s.pool.WaitForTask(id) } } @@ -128,8 +129,8 @@ func (s *service) ProblemJudge(meta *JudgeMeta, config *Config, cLang *ConfigLan ids := make([]int, 0) for _, task := range config.Tasks { - f := func(id int) func() { - return func() { + f := func(id int) func() error { + return func() error { ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id)) jdgFile := filepath.Join(workDir, fmt.Sprintf("%d.judge", id)) @@ -182,6 +183,7 @@ func (s *service) ProblemJudge(meta *JudgeMeta, config *Config, cLang *ConfigLan if err != nil { s.log.Info("[judge] judge failed", zap.Error(err), zap.Any("meta", *meta)) } + return err } }(task.Id) @@ -190,7 +192,7 @@ func (s *service) ProblemJudge(meta *JudgeMeta, config *Config, cLang *ConfigLan } for _, id := range ids { - s.pool.WaitForTask(id) + _ = s.pool.WaitForTask(id) } } diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index f1a04ca..4106fdb 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -1,6 +1,7 @@ package pool import ( + "errors" "sync" ) @@ -11,14 +12,14 @@ type TaskPool struct { lck sync.Mutex curTaskID int - waitMap map[int]chan struct{} + waitMap map[int]chan error } func NewTaskPool(maxWorkers, bufferSize int) *TaskPool { return &TaskPool{ workers: maxWorkers, queue: make(chan Task, bufferSize), - waitMap: make(map[int]chan struct{}), + waitMap: make(map[int]chan error), curTaskID: 1, // task id starts from 1 } } @@ -31,13 +32,13 @@ func (tp *TaskPool) Start() { } } -func (tp *TaskPool) AddTask(f func()) int { +func (tp *TaskPool) AddTask(f func() error) int { tp.lck.Lock() id := tp.curTaskID tp.curTaskID++ - waitChan := make(chan struct{}) + waitChan := make(chan error, 1) tp.waitMap[id] = waitChan tp.lck.Unlock() @@ -48,7 +49,26 @@ func (tp *TaskPool) AddTask(f func()) int { return id } -func (tp *TaskPool) WaitForTask(taskID int) { +func (tp *TaskPool) WaitForTask(taskID int) error { + tp.lck.Lock() + waitChan, ok := tp.waitMap[taskID] + if !ok { + tp.lck.Unlock() + return errors.New("task not found") + } + tp.lck.Unlock() + + ret := <-waitChan + close(waitChan) + + tp.lck.Lock() + delete(tp.waitMap, taskID) + tp.lck.Unlock() + + return ret +} + +func (tp *TaskPool) markTaskComplete(taskID int, err error) { tp.lck.Lock() waitChan, ok := tp.waitMap[taskID] if !ok { @@ -57,20 +77,7 @@ func (tp *TaskPool) WaitForTask(taskID int) { } tp.lck.Unlock() - <-waitChan -} - -func (tp *TaskPool) markTaskComplete(taskID int) { - tp.lck.Lock() - defer tp.lck.Unlock() - - waitChan, ok := tp.waitMap[taskID] - if !ok { - return - } - - close(waitChan) - delete(tp.waitMap, taskID) + waitChan <- err } func (tp *TaskPool) Stop() { diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go index f82680c..a06389e 100644 --- a/pkg/pool/pool_test.go +++ b/pkg/pool/pool_test.go @@ -1,6 +1,8 @@ package pool import ( + "errors" + "strconv" "sync" "testing" "time" @@ -14,8 +16,8 @@ func TestTaskPool_Stop(t *testing.T) { counter := 0 for i := 1; i <= 10; i++ { - f := func(i int) func() { - return func() { + f := func(i int) func() error { + return func() error { lck.Lock() t.Log("task", i, "locked") counter += i @@ -24,6 +26,8 @@ func TestTaskPool_Stop(t *testing.T) { time.Sleep(time.Duration(i*100) * time.Millisecond) t.Log("task", i, "finished") + + return nil } }(i) pool.AddTask(f) @@ -43,18 +47,22 @@ func TestTaskPool_WaitForTask(t *testing.T) { counter := 0 for i := 1; i <= 10; i++ { - f := func(i int) func() { - return func() { + f := func(i int) func() error { + return func() error { counter += 1 t.Log("task", i, "finished") + return errors.New(strconv.Itoa(i)) } }(i) id := pool.AddTask(f) - pool.WaitForTask(id) + ret := pool.WaitForTask(id) if counter != 1 { t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id) } + if ret.Error() != strconv.Itoa(i) { + t.Errorf("Return value mismatch: expected %s, got %s, task %d", strconv.Itoa(i), ret.Error(), id) + } counter -= 1 } @@ -70,8 +78,8 @@ func TestTaskPool_One(t *testing.T) { ids := make([]int, 0) for i := 1; i <= 10; i++ { - f := func(i int) func() { - return func() { + f := func(i int) func() error { + return func() error { lck.Lock() t.Log("task", i, "locked") counter += i @@ -80,6 +88,8 @@ func TestTaskPool_One(t *testing.T) { time.Sleep(time.Duration(i*10) * time.Millisecond) t.Log("task", i, "finished") + + return nil } }(i) id := pool.AddTask(f) @@ -87,7 +97,7 @@ func TestTaskPool_One(t *testing.T) { } for _, id := range ids { - pool.WaitForTask(id) + _ = pool.WaitForTask(id) } if counter != 55 { diff --git a/pkg/pool/task.go b/pkg/pool/task.go index d3f8b7c..7efa300 100644 --- a/pkg/pool/task.go +++ b/pkg/pool/task.go @@ -2,5 +2,5 @@ package pool type Task struct { id int - f func() + f func() error } diff --git a/pkg/pool/worker.go b/pkg/pool/worker.go index 8cd5473..55e8fef 100644 --- a/pkg/pool/worker.go +++ b/pkg/pool/worker.go @@ -18,7 +18,7 @@ func (w *Worker) Start(wg *sync.WaitGroup) { defer wg.Done() for task := range w.queue { - task.f() - w.pool.markTaskComplete(task.id) + err := task.f() + w.pool.markTaskComplete(task.id, err) } }