140 lines
2.9 KiB
Go
140 lines
2.9 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"git.0x7f.app/WOJ/woj-server/internal/misc/config"
|
|
"git.0x7f.app/WOJ/woj-server/internal/misc/log"
|
|
"git.0x7f.app/WOJ/woj-server/internal/model"
|
|
"github.com/samber/do"
|
|
"go.uber.org/zap"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/schema"
|
|
"moul.io/zapgorm2"
|
|
"time"
|
|
)
|
|
|
|
// Compile-time assertion that *service implements Service.
var _ Service = (*service)(nil)
|
|
|
|
// Service exposes the application's PostgreSQL database through GORM.
type Service interface {
	// Get returns the GORM handle used to issue queries.
	Get() *gorm.DB
	// Close closes the underlying database connection pool.
	Close() error
	// HealthCheck reports an error if the service is not in a usable state.
	HealthCheck() error
}
|
|
|
|
func NewService(i *do.Injector) (Service, error) {
|
|
srv := &service{}
|
|
srv.log = do.MustInvoke[log.Service](i).GetLogger("postgresql")
|
|
|
|
conf := do.MustInvoke[config.Service](i).GetConfig()
|
|
srv.setup(conf)
|
|
return srv, srv.err
|
|
}
|
|
|
|
// service is the GORM-backed implementation of Service.
type service struct {
	// log is the "postgresql" logger obtained in NewService.
	log *zap.Logger
	// db is the handle assigned by setup via gorm.Open.
	db *gorm.DB
	// err is the last error recorded by setup or Close; HealthCheck returns it.
	err error
}
|
|
|
|
func (s *service) Get() *gorm.DB {
|
|
return s.db
|
|
}
|
|
|
|
func (s *service) Close() error {
|
|
var db *sql.DB
|
|
db, s.err = s.db.DB()
|
|
if s.err != nil {
|
|
return s.err
|
|
}
|
|
|
|
s.err = db.Close()
|
|
return s.err
|
|
}
|
|
|
|
func (s *service) HealthCheck() error {
|
|
return s.err
|
|
}
|
|
|
|
func (s *service) setup(conf *model.Config) {
|
|
s.log.Info("Connecting to database...")
|
|
|
|
logger := zapgorm2.New(s.log)
|
|
logger.IgnoreRecordNotFoundError = true
|
|
|
|
dsn := fmt.Sprintf(
|
|
// TODO: timezone as config
|
|
"user=%s password=%s dbname=%s host=%s port=%d sslmode=disable TimeZone=Asia/Shanghai",
|
|
conf.Database.User,
|
|
conf.Database.Password,
|
|
conf.Database.Database,
|
|
conf.Database.Host,
|
|
conf.Database.Port,
|
|
)
|
|
|
|
s.db, s.err = gorm.Open(
|
|
postgres.Open(dsn),
|
|
&gorm.Config{
|
|
NamingStrategy: schema.NamingStrategy{
|
|
SingularTable: true,
|
|
TablePrefix: conf.Database.Prefix,
|
|
},
|
|
PrepareStmt: true,
|
|
Logger: logger,
|
|
},
|
|
)
|
|
|
|
if s.err != nil {
|
|
s.log.Error("Failed to connect to database", zap.Error(s.err))
|
|
return
|
|
}
|
|
|
|
var db *sql.DB
|
|
db, s.err = s.checkAlive(3)
|
|
if s.err != nil {
|
|
s.log.Error("Database is not alive", zap.Error(s.err))
|
|
return
|
|
}
|
|
|
|
db.SetMaxOpenConns(conf.Database.MaxOpenConns)
|
|
db.SetMaxIdleConns(conf.Database.MaxIdleConns)
|
|
db.SetConnMaxLifetime(time.Duration(conf.Database.ConnMaxLifetime) * time.Minute)
|
|
|
|
s.migrateDatabase()
|
|
}
|
|
|
|
func (s *service) migrateDatabase() {
|
|
s.log.Info("Auto Migrating database...")
|
|
|
|
_ = s.db.AutoMigrate(&model.User{})
|
|
_ = s.db.AutoMigrate(&model.Problem{})
|
|
_ = s.db.AutoMigrate(&model.ProblemVersion{})
|
|
_ = s.db.AutoMigrate(&model.Submission{})
|
|
_ = s.db.AutoMigrate(&model.Status{})
|
|
}
|
|
|
|
func (s *service) checkAlive(retry int) (*sql.DB, error) {
|
|
if retry <= 0 {
|
|
return nil, errors.New("all retries are used up. failed to connect to database")
|
|
}
|
|
|
|
db, err := s.db.DB()
|
|
if err != nil {
|
|
s.log.Warn("failed to get sql.DB instance", zap.Error(err))
|
|
time.Sleep(5 * time.Second)
|
|
return s.checkAlive(retry - 1)
|
|
}
|
|
|
|
err = db.Ping()
|
|
if err != nil {
|
|
s.log.Warn("failed to ping database", zap.Error(err))
|
|
time.Sleep(5 * time.Second)
|
|
return s.checkAlive(retry - 1)
|
|
}
|
|
|
|
s.log.Info("database connect established")
|
|
return db, nil
|
|
}
|