feat: always use sso for user
This commit is contained in:
parent
aaac7f57a0
commit
6cbb1ab4a9
@ -7,6 +7,7 @@ import (
|
||||
"git.0x7f.app/WOJ/woj-server/internal/model"
|
||||
"git.0x7f.app/WOJ/woj-server/internal/service/user"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jackc/pgtype"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@ -15,13 +16,17 @@ import (
|
||||
// @Description Callback endpoint from OAuth2
|
||||
// @Tags oauth
|
||||
// @Produce json
|
||||
// @Router /oauth/callback [get]
|
||||
// @Router /v1/oauth/callback [get]
|
||||
func (h *handler) CallbackHandler() gin.HandlerFunc {
|
||||
// TODO: Figure out a better way to cooperate with frontend
|
||||
// Currently using /login?redirect_token=xxx to pass jwt token
|
||||
// /error?message=xxx to pass error message
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Extract key from cookie
|
||||
key, err := c.Cookie(oauthStateCookieName)
|
||||
if err != nil {
|
||||
e.Pong[any](c, e.InvalidParameter, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.InvalidParameter.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
@ -29,7 +34,7 @@ func (h *handler) CallbackHandler() gin.HandlerFunc {
|
||||
key = fmt.Sprintf(oauthStateKey, key)
|
||||
expected, err := h.cache.Get().Get(context.Background(), key).Result()
|
||||
if err != nil {
|
||||
e.Pong[any](c, e.RedisError, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.RedisError.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
@ -38,59 +43,81 @@ func (h *handler) CallbackHandler() gin.HandlerFunc {
|
||||
|
||||
// Verify state
|
||||
if c.Query("state") != expected {
|
||||
e.Pong[any](c, e.OAuthStateMismatch, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.OAuthStateMismatch.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := h.conf.Exchange(context.Background(), c.Query("code"))
|
||||
if err != nil {
|
||||
e.Pong[any](c, e.OAuthExchangeFailed, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.OAuthExchangeFailed.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
// Extract the ID Token from OAuth2 token.
|
||||
raw, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
e.Pong[any](c, e.OAuthExchangeFailed, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.OAuthExchangeFailed.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
idToken, err := h.verifier.Verify(context.Background(), raw)
|
||||
if err != nil {
|
||||
e.Pong[any](c, e.OAuthVerifyFailed, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.OAuthVerifyFailed.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
// Extract custom claims
|
||||
// TODO: extract role from claims
|
||||
// TODO: extract role from claims: need to modify oidc provider
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Nickname string `json:"preferred_username"`
|
||||
Role string `json:"role"`
|
||||
Sub string `json:"sub"`
|
||||
Name string `json:"name"`
|
||||
|
||||
// Exp uint64 `json:"exp"`
|
||||
// Iat uint64 `json:"iat"`
|
||||
// AuthTime uint64 `json:"auth_time"`
|
||||
// Jti string `json:"jti"`
|
||||
// Iss string `json:"iss"`
|
||||
// Aud string `json:"aud"`
|
||||
// Typ string `json:"typ"`
|
||||
// Azp string `json:"azp"`
|
||||
// SessionState string `json:"session_state"`
|
||||
// AtHash string `json:"at_hash"`
|
||||
// Acr string `json:"acr"`
|
||||
// Sid string `json:"sid"`
|
||||
// PreferredUsername string `json:"preferred_username"`
|
||||
// GivenName string `json:"given_name"`
|
||||
// FamilyName string `json:"family_name"`
|
||||
}
|
||||
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
e.Pong[any](c, e.OAuthGetClaimsFailed, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.OAuthGetClaimsFailed.QueryString())
|
||||
return
|
||||
}
|
||||
if !claims.EmailVerified || claims.Email == "" || claims.Nickname == "" {
|
||||
e.Pong[any](c, e.UserInvalid, nil)
|
||||
|
||||
if claims.Name == "" || claims.Sub == "" {
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.UserInvalid.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
uid := pgtype.UUID{}
|
||||
if err := uid.Set(claims.Sub); err != nil {
|
||||
c.Redirect(http.StatusFound, "/error?message="+e.UserInvalid.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
// Check user existence
|
||||
u, status := h.user.ProfileOrCreate(&user.CreateData{Email: claims.Email, NickName: claims.Nickname})
|
||||
u, status := h.user.ProfileOrCreate(&user.CreateData{UID: uid, NickName: claims.Name})
|
||||
if status != e.Success {
|
||||
e.Pong[any](c, status, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+status.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
// Increment user version
|
||||
version, status := h.user.IncrVersion(u.ID)
|
||||
if status != e.Success {
|
||||
e.Pong[any](c, status, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+status.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
@ -102,11 +129,10 @@ func (h *handler) CallbackHandler() gin.HandlerFunc {
|
||||
}
|
||||
jwt, status := h.jwt.SignClaim(claim)
|
||||
if status != e.Success {
|
||||
e.Pong[any](c, status, nil)
|
||||
c.Redirect(http.StatusFound, "/error?message="+status.QueryString())
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Figure out a better way to cooperate with frontend
|
||||
c.Redirect(http.StatusFound, "/login?redirect_token="+jwt)
|
||||
}
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func RouteRegister(rg *gin.RouterGroup, i *do.Injector) {
|
||||
ClientSecret: conf.WebServer.OAuth.ClientSecret,
|
||||
RedirectURL: conf.WebServer.PublicBase + rg.BasePath() + "/callback",
|
||||
Endpoint: app.provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "roles"},
|
||||
Scopes: []string{oidc.ScopeOpenID, "email", "profile"},
|
||||
}
|
||||
|
||||
rg.POST("/login", app.LoginHandler())
|
||||
|
@ -23,7 +23,7 @@ type LoginResponse struct {
|
||||
// @Tags oauth
|
||||
// @Produce json
|
||||
// @Response 200 {object} e.Response[oauth.LoginResponse] "random string"
|
||||
// @Router /oauth/login [post]
|
||||
// @Router /v1/oauth/login [post]
|
||||
func (h *handler) LoginHandler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
state := utils.RandomString(64)
|
||||
|
@ -1,69 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"git.0x7f.app/WOJ/woj-server/internal/e"
|
||||
"git.0x7f.app/WOJ/woj-server/internal/model"
|
||||
"git.0x7f.app/WOJ/woj-server/internal/service/user"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/mail"
|
||||
)
|
||||
|
||||
type createRequest struct {
|
||||
Email string `form:"email" json:"email" binding:"required"`
|
||||
NickName string `form:"nickname" json:"nickname" binding:"required"`
|
||||
Password string `form:"password" json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
// Create
|
||||
// @Summary create a new user
|
||||
// @Description create a new user
|
||||
// @Tags user
|
||||
// @Accept application/x-www-form-urlencoded
|
||||
// @Produce json
|
||||
// @Param email formData string true "email"
|
||||
// @Param nickname formData string true "nickname"
|
||||
// @Param password formData string true "password"
|
||||
// @Response 200 {object} e.Response[string] "jwt token"
|
||||
// @Router /v1/user/create [post]
|
||||
func (h *handler) Create(c *gin.Context) {
|
||||
req := new(createRequest)
|
||||
if err := c.ShouldBind(req); err != nil {
|
||||
e.Pong(c, e.InvalidParameter, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// verify email is valid
|
||||
_, err := mail.ParseAddress(req.Email)
|
||||
if err != nil {
|
||||
e.Pong[any](c, e.InvalidParameter, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// create user
|
||||
createData := &user.CreateData{
|
||||
Email: req.Email,
|
||||
NickName: req.NickName,
|
||||
Password: req.Password,
|
||||
}
|
||||
u, status := h.userService.Create(createData)
|
||||
if status != e.Success {
|
||||
e.Pong[any](c, status, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// update version in cache
|
||||
version, status := h.userService.IncrVersion(u.ID)
|
||||
if status != e.Success {
|
||||
e.Pong[any](c, status, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// sign jwt token
|
||||
claim := &model.Claim{
|
||||
UID: u.ID,
|
||||
Role: u.Role,
|
||||
Version: version,
|
||||
}
|
||||
token, status := h.jwtService.SignClaim(claim)
|
||||
e.Pong(c, status, token)
|
||||
}
|
@ -12,8 +12,6 @@ import (
|
||||
var _ Handler = (*handler)(nil)
|
||||
|
||||
type Handler interface {
|
||||
Create(c *gin.Context)
|
||||
Login(c *gin.Context)
|
||||
Logout(c *gin.Context)
|
||||
Profile(c *gin.Context)
|
||||
}
|
||||
@ -31,8 +29,6 @@ func RouteRegister(rg *gin.RouterGroup, i *do.Injector) {
|
||||
userService: do.MustInvoke[user.Service](i),
|
||||
}
|
||||
|
||||
rg.POST("/create", app.Create)
|
||||
rg.POST("/login", app.Login)
|
||||
rg.POST("/logout", app.jwtService.Handler(true), app.Logout)
|
||||
rg.POST("/profile", app.jwtService.Handler(true), app.Profile)
|
||||
}
|
||||
|
@ -1,61 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"git.0x7f.app/WOJ/woj-server/internal/e"
|
||||
"git.0x7f.app/WOJ/woj-server/internal/model"
|
||||
"git.0x7f.app/WOJ/woj-server/internal/service/user"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `form:"email" json:"email" binding:"required"`
|
||||
Password string `form:"password" json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
NickName string `json:"nickname"`
|
||||
}
|
||||
|
||||
// Login
|
||||
// @Summary login
|
||||
// @Description login and return token
|
||||
// @Tags user
|
||||
// @Accept application/x-www-form-urlencoded
|
||||
// @Produce json
|
||||
// @Param email formData string true "email"
|
||||
// @Param password formData string true "password"
|
||||
// @Response 200 {object} e.Response[LoginResponse] "jwt token and user's nickname"
|
||||
// @Router /v1/user/login [post]
|
||||
func (h *handler) Login(c *gin.Context) {
|
||||
req := new(loginRequest)
|
||||
if err := c.ShouldBind(req); err != nil {
|
||||
e.Pong(c, e.InvalidParameter, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// check password
|
||||
loginData := &user.LoginData{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
}
|
||||
u, status := h.userService.Login(loginData)
|
||||
if status != e.Success {
|
||||
e.Pong[any](c, status, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// sign and return token
|
||||
version, status := h.userService.IncrVersion(u.ID)
|
||||
if status != e.Success {
|
||||
e.Pong[any](c, status, nil)
|
||||
return
|
||||
}
|
||||
claim := &model.Claim{
|
||||
UID: u.ID,
|
||||
Role: u.Role,
|
||||
Version: version,
|
||||
}
|
||||
token, status := h.jwtService.SignClaim(claim)
|
||||
e.Pong(c, status, LoginResponse{Token: token, NickName: u.NickName})
|
||||
}
|
@ -1,6 +1,9 @@
|
||||
package e
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type Status int
|
||||
|
||||
@ -12,6 +15,10 @@ func (code Status) String() string {
|
||||
return msgText[InternalError]
|
||||
}
|
||||
|
||||
func (code Status) QueryString() string {
|
||||
return url.QueryEscape(code.String())
|
||||
}
|
||||
|
||||
func (code Status) AsError() error {
|
||||
return errors.New(code.String())
|
||||
}
|
||||
|
@ -1,14 +1,14 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgtype"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
gorm.Model `json:"meta"`
|
||||
Email string `json:"email" gorm:"not null;uniqueIndex"`
|
||||
NickName string `json:"nick_name" gorm:"not null;uniqueIndex"`
|
||||
UID pgtype.UUID `json:"-" gorm:"not null;uniqueIndex"`
|
||||
NickName string `json:"nick_name" gorm:"not null"`
|
||||
Role Role `json:"role" gorm:"not null"`
|
||||
Password []byte `json:"-"`
|
||||
IsEnabled bool `json:"is_enabled" gorm:"not null;index"`
|
||||
}
|
||||
|
@ -3,33 +3,25 @@ package user
|
||||
import (
|
||||
"git.0x7f.app/WOJ/woj-server/internal/e"
|
||||
"git.0x7f.app/WOJ/woj-server/internal/model"
|
||||
"github.com/jackc/pgtype"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type CreateData struct {
|
||||
Email string
|
||||
UID pgtype.UUID
|
||||
NickName string
|
||||
Password string
|
||||
}
|
||||
|
||||
func (s *service) Create(data *CreateData) (*model.User, e.Status) {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(data.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
s.log.Warn("BcryptError", zap.Error(err), zap.String("password", data.Password))
|
||||
return nil, e.InternalError
|
||||
}
|
||||
|
||||
user := &model.User{
|
||||
Email: data.Email,
|
||||
UID: data.UID,
|
||||
NickName: data.NickName,
|
||||
Password: hashed,
|
||||
Role: model.RoleGeneral,
|
||||
IsEnabled: true,
|
||||
}
|
||||
|
||||
err = s.db.Get().Create(user).Error
|
||||
err := s.db.Get().Create(user).Error
|
||||
if err != nil && strings.Contains(err.Error(), "duplicate key") {
|
||||
return nil, e.UserDuplicated
|
||||
}
|
||||
|
@ -1,43 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"git.0x7f.app/WOJ/woj-server/internal/e"
|
||||
"git.0x7f.app/WOJ/woj-server/internal/model"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LoginData struct {
|
||||
Email string
|
||||
Password string
|
||||
}
|
||||
|
||||
func (s *service) Login(data *LoginData) (*model.User, e.Status) {
|
||||
user := &model.User{Email: data.Email}
|
||||
|
||||
err := s.db.Get().Where(user).First(&user).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, e.UserNotFound
|
||||
}
|
||||
if err != nil {
|
||||
s.log.Warn("DatabaseError", zap.Error(err), zap.Any("user", user))
|
||||
return nil, e.DatabaseError
|
||||
}
|
||||
|
||||
if !user.IsEnabled {
|
||||
return nil, e.UserDisabled
|
||||
}
|
||||
if len(user.Password) == 0 {
|
||||
// created by oauth
|
||||
return nil, e.UserWithoutPassword
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword(user.Password, []byte(data.Password))
|
||||
if err != nil {
|
||||
return nil, e.UserWrongPassword
|
||||
}
|
||||
|
||||
return user, e.Success
|
||||
}
|
@ -25,19 +25,26 @@ func (s *service) Profile(uid uint) (*model.User, e.Status) {
|
||||
|
||||
func (s *service) ProfileOrCreate(data *CreateData) (*model.User, e.Status) {
|
||||
user := &model.User{
|
||||
Email: data.Email,
|
||||
UID: data.UID,
|
||||
NickName: data.NickName,
|
||||
Role: model.RoleGeneral,
|
||||
IsEnabled: true,
|
||||
}
|
||||
|
||||
// Notice: FirstOrCreate will not update the record if it exists, and also we should not update the record
|
||||
// Notice: OAuth2 created user will not have password
|
||||
err := s.db.Get().Where(model.User{Email: data.Email}).FirstOrCreate(&user, data).Error
|
||||
// Notice: FirstOrCreate will not update the record if it exists
|
||||
err := s.db.Get().Where(model.User{UID: data.UID}).FirstOrCreate(&user, data).Error
|
||||
if err != nil {
|
||||
s.log.Warn("DatabaseError", zap.Error(err), zap.Any("user", user))
|
||||
return nil, e.DatabaseError
|
||||
}
|
||||
|
||||
if user.NickName != data.NickName {
|
||||
err = s.db.Get().Model(&user).Update("nick_name", data.NickName).Error
|
||||
if err != nil {
|
||||
s.log.Warn("DatabaseError", zap.Error(err), zap.Any("user", user))
|
||||
return nil, e.DatabaseError
|
||||
}
|
||||
}
|
||||
|
||||
return user, e.Success
|
||||
}
|
||||
|
@ -14,7 +14,6 @@ var _ Service = (*service)(nil)
|
||||
|
||||
type Service interface {
|
||||
Create(data *CreateData) (*model.User, e.Status)
|
||||
Login(data *LoginData) (*model.User, e.Status)
|
||||
IncrVersion(uid uint) (int64, e.Status)
|
||||
Profile(uid uint) (*model.User, e.Status)
|
||||
ProfileOrCreate(data *CreateData) (*model.User, e.Status)
|
||||
|
Loading…
Reference in New Issue
Block a user