feat: compile and new_problem are queued by pool (#7)

This commit is contained in:
Paul Pan 2024-01-28 18:16:04 +08:00
parent 8429988cb4
commit 6965435cb7
Signed by: Paul
GPG Key ID: D639BDF5BA578AF4
8 changed files with 63 additions and 38 deletions

View File

@ -97,7 +97,8 @@ func (s *service) Compile(meta *JudgeMeta) (*JudgeStatus, e.Status) {
}, },
} }
return s.ContainerRun(args) id := s.ContainerRunPool(args)
return s.pool.WaitForTask(id)
}). }).
Done() Done()

View File

@ -136,3 +136,7 @@ func (s *service) ContainerRun(arg *RunArgs) error {
return nil return nil
} }
func (s *service) ContainerRunPool(arg *RunArgs) int {
return s.pool.AddTask(func() error { return s.ContainerRun(arg) })
}

View File

@ -79,7 +79,8 @@ func (s *service) PrebuildProblem(meta *JudgeMeta, config *Config, force bool) e
}, },
} }
err := s.ContainerRun(args) id := s.ContainerRunPool(args)
err := s.pool.WaitForTask(id)
if err != nil { if err != nil {
s.log.Warn("[new] prebuild problem failed", zap.Error(err), zap.Uint("version", meta.Version)) s.log.Warn("[new] prebuild problem failed", zap.Error(err), zap.Uint("version", meta.Version))

View File

@ -46,8 +46,8 @@ func (s *service) ProblemRun(meta *JudgeMeta, config *Config, cLang *ConfigLangu
ids := make([]int, 0) ids := make([]int, 0)
for _, task := range config.Tasks { for _, task := range config.Tasks {
f := func(id int) func() { f := func(id int) func() error {
return func() { return func() error {
testCase := filepath.Join(dataDir, fmt.Sprintf("%d.input", id)) testCase := filepath.Join(dataDir, fmt.Sprintf("%d.input", id))
targetFile := filepath.Join(workDir, fmt.Sprintf("%s.out", meta.User)) targetFile := filepath.Join(workDir, fmt.Sprintf("%s.out", meta.User))
ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id)) ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id))
@ -101,6 +101,7 @@ func (s *service) ProblemRun(meta *JudgeMeta, config *Config, cLang *ConfigLangu
if err != nil { if err != nil {
s.log.Info("[run] run failed", zap.Error(err), zap.Any("meta", *meta)) s.log.Info("[run] run failed", zap.Error(err), zap.Any("meta", *meta))
} }
return err
} }
}(task.Id) }(task.Id)
@ -109,7 +110,7 @@ func (s *service) ProblemRun(meta *JudgeMeta, config *Config, cLang *ConfigLangu
} }
for _, id := range ids { for _, id := range ids {
s.pool.WaitForTask(id) _ = s.pool.WaitForTask(id)
} }
} }
@ -128,8 +129,8 @@ func (s *service) ProblemJudge(meta *JudgeMeta, config *Config, cLang *ConfigLan
ids := make([]int, 0) ids := make([]int, 0)
for _, task := range config.Tasks { for _, task := range config.Tasks {
f := func(id int) func() { f := func(id int) func() error {
return func() { return func() error {
ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id)) ansFile := filepath.Join(workDir, fmt.Sprintf("%d.out.usr", id))
jdgFile := filepath.Join(workDir, fmt.Sprintf("%d.judge", id)) jdgFile := filepath.Join(workDir, fmt.Sprintf("%d.judge", id))
@ -182,6 +183,7 @@ func (s *service) ProblemJudge(meta *JudgeMeta, config *Config, cLang *ConfigLan
if err != nil { if err != nil {
s.log.Info("[judge] judge failed", zap.Error(err), zap.Any("meta", *meta)) s.log.Info("[judge] judge failed", zap.Error(err), zap.Any("meta", *meta))
} }
return err
} }
}(task.Id) }(task.Id)
@ -190,7 +192,7 @@ func (s *service) ProblemJudge(meta *JudgeMeta, config *Config, cLang *ConfigLan
} }
for _, id := range ids { for _, id := range ids {
s.pool.WaitForTask(id) _ = s.pool.WaitForTask(id)
} }
} }

View File

@ -1,6 +1,7 @@
package pool package pool
import ( import (
"errors"
"sync" "sync"
) )
@ -11,14 +12,14 @@ type TaskPool struct {
lck sync.Mutex lck sync.Mutex
curTaskID int curTaskID int
waitMap map[int]chan struct{} waitMap map[int]chan error
} }
func NewTaskPool(maxWorkers, bufferSize int) *TaskPool { func NewTaskPool(maxWorkers, bufferSize int) *TaskPool {
return &TaskPool{ return &TaskPool{
workers: maxWorkers, workers: maxWorkers,
queue: make(chan Task, bufferSize), queue: make(chan Task, bufferSize),
waitMap: make(map[int]chan struct{}), waitMap: make(map[int]chan error),
curTaskID: 1, // task id starts from 1 curTaskID: 1, // task id starts from 1
} }
} }
@ -31,13 +32,13 @@ func (tp *TaskPool) Start() {
} }
} }
func (tp *TaskPool) AddTask(f func()) int { func (tp *TaskPool) AddTask(f func() error) int {
tp.lck.Lock() tp.lck.Lock()
id := tp.curTaskID id := tp.curTaskID
tp.curTaskID++ tp.curTaskID++
waitChan := make(chan struct{}) waitChan := make(chan error, 1)
tp.waitMap[id] = waitChan tp.waitMap[id] = waitChan
tp.lck.Unlock() tp.lck.Unlock()
@ -48,7 +49,26 @@ func (tp *TaskPool) AddTask(f func()) int {
return id return id
} }
func (tp *TaskPool) WaitForTask(taskID int) { func (tp *TaskPool) WaitForTask(taskID int) error {
tp.lck.Lock()
waitChan, ok := tp.waitMap[taskID]
if !ok {
tp.lck.Unlock()
return errors.New("task not found")
}
tp.lck.Unlock()
ret := <-waitChan
close(waitChan)
tp.lck.Lock()
delete(tp.waitMap, taskID)
tp.lck.Unlock()
return ret
}
func (tp *TaskPool) markTaskComplete(taskID int, err error) {
tp.lck.Lock() tp.lck.Lock()
waitChan, ok := tp.waitMap[taskID] waitChan, ok := tp.waitMap[taskID]
if !ok { if !ok {
@ -57,20 +77,7 @@ func (tp *TaskPool) WaitForTask(taskID int) {
} }
tp.lck.Unlock() tp.lck.Unlock()
<-waitChan waitChan <- err
}
func (tp *TaskPool) markTaskComplete(taskID int) {
tp.lck.Lock()
defer tp.lck.Unlock()
waitChan, ok := tp.waitMap[taskID]
if !ok {
return
}
close(waitChan)
delete(tp.waitMap, taskID)
} }
func (tp *TaskPool) Stop() { func (tp *TaskPool) Stop() {

View File

@ -1,6 +1,8 @@
package pool package pool
import ( import (
"errors"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -14,8 +16,8 @@ func TestTaskPool_Stop(t *testing.T) {
counter := 0 counter := 0
for i := 1; i <= 10; i++ { for i := 1; i <= 10; i++ {
f := func(i int) func() { f := func(i int) func() error {
return func() { return func() error {
lck.Lock() lck.Lock()
t.Log("task", i, "locked") t.Log("task", i, "locked")
counter += i counter += i
@ -24,6 +26,8 @@ func TestTaskPool_Stop(t *testing.T) {
time.Sleep(time.Duration(i*100) * time.Millisecond) time.Sleep(time.Duration(i*100) * time.Millisecond)
t.Log("task", i, "finished") t.Log("task", i, "finished")
return nil
} }
}(i) }(i)
pool.AddTask(f) pool.AddTask(f)
@ -43,18 +47,22 @@ func TestTaskPool_WaitForTask(t *testing.T) {
counter := 0 counter := 0
for i := 1; i <= 10; i++ { for i := 1; i <= 10; i++ {
f := func(i int) func() { f := func(i int) func() error {
return func() { return func() error {
counter += 1 counter += 1
t.Log("task", i, "finished") t.Log("task", i, "finished")
return errors.New(strconv.Itoa(i))
} }
}(i) }(i)
id := pool.AddTask(f) id := pool.AddTask(f)
pool.WaitForTask(id) ret := pool.WaitForTask(id)
if counter != 1 { if counter != 1 {
t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id) t.Errorf("Counter mismatch: expected %d, got %d, task %d", 1, counter, id)
} }
if ret.Error() != strconv.Itoa(i) {
t.Errorf("Return value mismatch: expected %s, got %s, task %d", strconv.Itoa(i), ret.Error(), id)
}
counter -= 1 counter -= 1
} }
@ -70,8 +78,8 @@ func TestTaskPool_One(t *testing.T) {
ids := make([]int, 0) ids := make([]int, 0)
for i := 1; i <= 10; i++ { for i := 1; i <= 10; i++ {
f := func(i int) func() { f := func(i int) func() error {
return func() { return func() error {
lck.Lock() lck.Lock()
t.Log("task", i, "locked") t.Log("task", i, "locked")
counter += i counter += i
@ -80,6 +88,8 @@ func TestTaskPool_One(t *testing.T) {
time.Sleep(time.Duration(i*10) * time.Millisecond) time.Sleep(time.Duration(i*10) * time.Millisecond)
t.Log("task", i, "finished") t.Log("task", i, "finished")
return nil
} }
}(i) }(i)
id := pool.AddTask(f) id := pool.AddTask(f)
@ -87,7 +97,7 @@ func TestTaskPool_One(t *testing.T) {
} }
for _, id := range ids { for _, id := range ids {
pool.WaitForTask(id) _ = pool.WaitForTask(id)
} }
if counter != 55 { if counter != 55 {

View File

@ -2,5 +2,5 @@ package pool
type Task struct { type Task struct {
id int id int
f func() f func() error
} }

View File

@ -18,7 +18,7 @@ func (w *Worker) Start(wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
for task := range w.queue { for task := range w.queue {
task.f() err := task.f()
w.pool.markTaskComplete(task.id) w.pool.markTaskComplete(task.id, err)
} }
} }