woj-server/internal/repo/db/pg.go

165 lines
3.6 KiB
Go

package db
import (
	"database/sql"
	"errors"
	"fmt"
	"hash/fnv"
	"strings"
	"time"

	"git.0x7f.app/WOJ/woj-server/internal/misc/config"
	"git.0x7f.app/WOJ/woj-server/internal/misc/log"
	"git.0x7f.app/WOJ/woj-server/internal/model"
	"git.0x7f.app/WOJ/woj-server/pkg/utils"
	"github.com/samber/do"
	"go.uber.org/zap"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
	"moul.io/zapgorm2"
)
// Compile-time check that *service implements Service.
var _ Service = (*service)(nil)

// Service exposes the PostgreSQL-backed gorm handle and its lifecycle.
type Service interface {
	// Get returns the gorm handle used for queries.
	Get() *gorm.DB
	// Close shuts down the underlying database connection pool.
	Close() error
	// HealthCheck reports the last error recorded by the service, or nil.
	HealthCheck() error
}
// NewService builds the PostgreSQL Service from the injector's logger and
// config services, then runs setup (connect, ping, pool tuning, migrations).
// The service value is returned together with whatever error setup recorded,
// so the value is non-nil even when the error is non-nil.
func NewService(i *do.Injector) (Service, error) {
	logger := do.MustInvoke[log.Service](i).GetLogger("postgresql")
	cfg := do.MustInvoke[config.Service](i).GetConfig()

	srv := &service{log: logger}
	srv.setup(cfg)

	return srv, srv.err
}
// service is the concrete Service implementation.
type service struct {
	// log is the logger obtained for the "postgresql" component.
	log *zap.Logger
	// db is the gorm handle assigned by setup.
	db *gorm.DB
	// err holds the most recent error from setup, migration or Close;
	// HealthCheck returns it.
	err error
}
// Get hands out the gorm handle stored on the service by setup.
func (s *service) Get() (handle *gorm.DB) {
	handle = s.db
	return handle
}
// Close unwraps the *sql.DB behind the gorm handle and closes it. The
// resulting error (nil on success) is also stored in s.err, so a later
// HealthCheck reflects the outcome of the close.
func (s *service) Close() error {
	sqlDB, err := s.db.DB()
	if err == nil {
		err = sqlDB.Close()
	}
	s.err = err
	return s.err
}
// HealthCheck returns the error recorded by the last setup/migration or
// Close call, or nil if none is recorded. It does not query the database.
func (s *service) HealthCheck() error {
	if recorded := s.err; recorded != nil {
		return recorded
	}
	return nil
}
// setup opens the gorm connection described by conf.Database, verifies it
// with checkAlive (3 attempts), applies the connection-pool limits from the
// config and finally runs migrateDatabase. Each failure is logged and left in
// s.err, and the remaining steps are skipped.
func (s *service) setup(conf *model.Config) {
	s.log.Info("Connecting to database...")

	logger := zapgorm2.New(s.log)
	logger.IgnoreRecordNotFoundError = true

	// String values are single-quoted: in the libpq key/value format an
	// unquoted value ends at the first whitespace, so an empty password
	// ("password= dbname=...") would swallow the next setting and a value
	// containing a space or quote would corrupt the DSN.
	dsn := fmt.Sprintf(
		// TODO: timezone as config
		"user=%s password=%s dbname=%s host=%s port=%d sslmode=disable TimeZone=Asia/Shanghai",
		quoteDSNValue(conf.Database.User),
		quoteDSNValue(conf.Database.Password),
		quoteDSNValue(conf.Database.Database),
		quoteDSNValue(conf.Database.Host),
		conf.Database.Port,
	)

	s.db, s.err = gorm.Open(
		postgres.Open(dsn),
		&gorm.Config{
			NamingStrategy: schema.NamingStrategy{
				SingularTable: true,
				TablePrefix:   conf.Database.Prefix,
			},
			PrepareStmt: true,
			Logger:      logger,
		},
	)
	if s.err != nil {
		s.log.Error("Failed to connect to database", zap.Error(s.err))
		return
	}

	var db *sql.DB
	db, s.err = s.checkAlive(3)
	if s.err != nil {
		s.log.Error("Database is not alive", zap.Error(s.err))
		return
	}

	db.SetMaxOpenConns(conf.Database.MaxOpenConns)
	db.SetMaxIdleConns(conf.Database.MaxIdleConns)
	// ConnMaxLifetime is configured in minutes.
	db.SetConnMaxLifetime(time.Duration(conf.Database.ConnMaxLifetime) * time.Minute)

	s.migrateDatabase()
}

// dsnValueEscaper backslash-escapes the two characters that are special
// inside a single-quoted libpq key/value value: backslash and single quote.
var dsnValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quoteDSNValue renders v as a single-quoted, escaped libpq key/value DSN
// value, so empty strings and strings with whitespace or quotes survive.
func quoteDSNValue(v string) string {
	return "'" + dsnValueEscaper.Replace(v) + "'"
}
// migrateDatabase runs gorm AutoMigrate for every model while holding a
// PostgreSQL advisory lock, so that server instances starting at the same
// time do not migrate concurrently. The outcome (nil on success) is stored
// in s.err and any failure is logged.
func (s *service) migrateDatabase() {
	s.log.Info("Auto Migrating database...")

	// Running AutoMigrate concurrently on the same model fails with various race conditions
	// https://github.com/go-gorm/gorm/pull/6680
	// https://github.com/go-gorm/postgres/pull/224
	// Obtain a lock to prevent concurrent AutoMigrate
	lockID := func(name string) int64 {
		h := fnv.New64a()
		_, err := h.Write([]byte(name))
		return utils.If(err != nil, int64(0x4242AA55), int64(h.Sum64()))
	}("gorm:migrator")

	// pg_advisory_lock is owned by the database session that acquired it and
	// can only be released from that same session. s.db is a connection pool,
	// so the lock, the migrations and the unlock are pinned to one connection.
	// Otherwise the unlock could run on another connection and leave the lock
	// held by an idle pooled connection.
	s.err = s.db.Connection(func(conn *gorm.DB) (err error) {
		if err = conn.Exec("SELECT pg_advisory_lock(?)", lockID).Error; err != nil {
			return fmt.Errorf("obtain migration lock: %w", err)
		}
		// Release the lock on every exit path, including a failed migration.
		defer func() {
			unlockErr := conn.Exec("SELECT pg_advisory_unlock(?)", lockID).Error
			if unlockErr == nil {
				return
			}
			if err == nil {
				err = fmt.Errorf("release migration lock: %w", unlockErr)
				return
			}
			// A migration error is already being returned; log this one so it
			// is not lost.
			s.log.Error("Failed to release lock", zap.Error(unlockErr))
		}()

		models := []any{
			&model.User{},
			&model.Problem{},
			&model.ProblemVersion{},
			&model.Submission{},
			&model.Status{},
		}
		for _, m := range models {
			if err = conn.AutoMigrate(m); err != nil {
				return fmt.Errorf("auto migrate %T: %w", m, err)
			}
		}
		return nil
	})
	if s.err != nil {
		s.log.Error("Failed to migrate database", zap.Error(s.err))
	}
}
// checkAlive obtains the *sql.DB behind s.db and pings it, making up to
// retry attempts with a 5-second pause between consecutive attempts (never
// after the last one). It returns the *sql.DB on the first successful ping.
// If retry <= 0, no attempt is made. When every attempt fails, the returned
// error wraps the last underlying error.
func (s *service) checkAlive(retry int) (*sql.DB, error) {
	const (
		retryDelay = 5 * time.Second
		exhausted  = "all retries are used up. failed to connect to database"
	)

	var lastErr error
	for attempt := 1; attempt <= retry; attempt++ {
		if attempt > 1 {
			time.Sleep(retryDelay)
		}

		db, err := s.db.DB()
		if err != nil {
			s.log.Warn("failed to get sql.DB instance", zap.Error(err), zap.Int("attempt", attempt))
			lastErr = err
			continue
		}

		if err = db.Ping(); err != nil {
			s.log.Warn("failed to ping database", zap.Error(err), zap.Int("attempt", attempt))
			lastErr = err
			continue
		}

		s.log.Info("database connect established")
		return db, nil
	}

	if lastErr == nil {
		return nil, errors.New(exhausted)
	}
	return nil, fmt.Errorf("%s: %w", exhausted, lastErr)
}