package db import ( "database/sql" "errors" "fmt" "git.0x7f.app/WOJ/woj-server/internal/misc/config" "git.0x7f.app/WOJ/woj-server/internal/misc/log" "git.0x7f.app/WOJ/woj-server/internal/model" "git.0x7f.app/WOJ/woj-server/pkg/utils" "github.com/samber/do" "go.uber.org/zap" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/schema" "hash/fnv" "moul.io/zapgorm2" "time" ) var _ Service = (*service)(nil) type Service interface { Get() *gorm.DB Migrate() HealthCheck() error Shutdown() error } func NewService(i *do.Injector) (Service, error) { srv := &service{} srv.log = do.MustInvoke[log.Service](i).GetLogger("postgresql") conf := do.MustInvoke[config.Service](i).GetConfig() srv.setup(conf) return srv, srv.err } type service struct { log *zap.Logger db *gorm.DB err error } func (s *service) Get() *gorm.DB { return s.db } func (s *service) Migrate() { s.log.Info("Auto Migrating database...") // Running AutoMigrate concurrently on the same model fails with various race conditions // https://github.com/go-gorm/gorm/pull/6680 // https://github.com/go-gorm/postgres/pull/224 // Obtain a lock to prevent concurrent AutoMigrate lockID := func(s string) int64 { h := fnv.New64a() _, err := h.Write([]byte(s)) return utils.If(err != nil, int64(0x4242AA55), int64(h.Sum64())) }("gorm:migrator") s.err = s.db.Exec("SELECT pg_advisory_lock(?)", lockID).Error if s.err != nil { s.log.Error("Failed to obtain lock", zap.Error(s.err)) return } _ = s.db.AutoMigrate(&model.User{}) _ = s.db.AutoMigrate(&model.Problem{}) _ = s.db.AutoMigrate(&model.ProblemVersion{}) _ = s.db.AutoMigrate(&model.Submission{}) _ = s.db.AutoMigrate(&model.Status{}) s.err = s.db.Exec("SELECT pg_advisory_unlock(?)", lockID).Error if s.err != nil { s.log.Error("Failed to release lock", zap.Error(s.err)) } } func (s *service) HealthCheck() error { return s.err } func (s *service) Shutdown() error { var db *sql.DB db, s.err = s.db.DB() if s.err != nil { return s.err } s.err = db.Close() return s.err } func (s *service) setup(conf *model.Config) { s.log.Info("Connecting to database...") logger := zapgorm2.New(s.log) logger.IgnoreRecordNotFoundError = true tz := utils.If(conf.Database.TimeZone == "", "Asia/Shanghai", conf.Database.TimeZone) dsn := fmt.Sprintf( "user=%s password=%s dbname=%s host=%s port=%d sslmode=disable TimeZone=%s", conf.Database.User, conf.Database.Password, conf.Database.Database, conf.Database.Host, conf.Database.Port, tz, ) s.db, s.err = gorm.Open( postgres.Open(dsn), &gorm.Config{ NamingStrategy: schema.NamingStrategy{ SingularTable: true, TablePrefix: conf.Database.Prefix, }, PrepareStmt: true, Logger: logger, }, ) if s.err != nil { s.log.Error("Failed to connect to database", zap.Error(s.err)) return } var db *sql.DB db, s.err = s.checkAlive(3) if s.err != nil { s.log.Error("Database is not alive", zap.Error(s.err)) return } db.SetMaxOpenConns(conf.Database.MaxOpenConns) db.SetMaxIdleConns(conf.Database.MaxIdleConns) db.SetConnMaxLifetime(time.Duration(conf.Database.ConnMaxLifetime) * time.Minute) } func (s *service) checkAlive(retry int) (*sql.DB, error) { if retry <= 0 { return nil, errors.New("all retries are used up. failed to connect to database") } db, err := s.db.DB() if err != nil { s.log.Warn("failed to get sql.DB instance", zap.Error(err)) time.Sleep(5 * time.Second) return s.checkAlive(retry - 1) } err = db.Ping() if err != nil { s.log.Warn("failed to ping database", zap.Error(err)) time.Sleep(5 * time.Second) return s.checkAlive(retry - 1) } s.log.Info("database connect established") return db, nil }