package oauth

import (
	"context"
	"fmt"
	"net/http"

	"git.0x7f.app/WOJ/woj-server/internal/e"
	"git.0x7f.app/WOJ/woj-server/internal/model"
	"git.0x7f.app/WOJ/woj-server/internal/service/user"
	"github.com/gin-gonic/gin"
	"github.com/jackc/pgtype"
)

// CallbackHandler returns the gin handler that finishes the OAuth2 login flow.
//
// The handler validates the "state" query parameter against the value stored
// under the key taken from the state cookie, exchanges the "code" query
// parameter for a token, verifies the embedded ID token, loads or creates the
// matching user, bumps the user's version and signs a JWT for it.
//
// The outcome is always reported through a 302 redirect:
//   - success: /login?redirect_token=<jwt>
//   - failure: /error?message=<status query string>
//
// @Summary Callback with OAuth2
// @Description Callback endpoint from OAuth2
// @Tags oauth
// @Produce json
// @Router /v1/oauth/callback [get]
func (h *handler) CallbackHandler() gin.HandlerFunc {
	// TODO: Figure out a better way to cooperate with frontend
	// Currently using /login?redirect_token=xxx to pass jwt token
	//                 /error?message=xxx to pass error message
	return func(c *gin.Context) {
		// Tie outbound calls to the incoming request so they are abandoned
		// when the client goes away instead of running to completion.
		ctx := c.Request.Context()

		// fail redirects the browser to the frontend error page; message must
		// already be query-escaped (callers pass Status.QueryString()).
		fail := func(message string) {
			c.Redirect(http.StatusFound, "/error?message="+message)
		}

		// Extract key from cookie
		key, err := c.Cookie(oauthStateCookieName)
		if err != nil {
			fail(e.InvalidParameter.QueryString())
			return
		}

		// Get state from redis
		key = fmt.Sprintf(oauthStateKey, key)
		expected, err := h.cache.Get().Get(ctx, key).Result()
		if err != nil {
			fail(e.RedisError.QueryString())
			return
		}

		// Whether or not the state turns out to be valid, delete it so it can
		// be used only once. This is best-effort (the result is intentionally
		// ignored) and uses a background context so that the cleanup is still
		// attempted if the request is cancelled right after the read above.
		h.cache.Get().Unlink(context.Background(), key)

		// Verify state
		if c.Query("state") != expected {
			fail(e.OAuthStateMismatch.QueryString())
			return
		}

		// Exchange code for token
		token, err := h.conf.Exchange(ctx, c.Query("code"))
		if err != nil {
			fail(e.OAuthExchangeFailed.QueryString())
			return
		}

		// Extract the ID Token from OAuth2 token.
		raw, ok := token.Extra("id_token").(string)
		if !ok {
			fail(e.OAuthExchangeFailed.QueryString())
			return
		}

		// Parse and verify ID Token payload.
		idToken, err := h.verifier.Verify(ctx, raw)
		if err != nil {
			fail(e.OAuthVerifyFailed.QueryString())
			return
		}

		// Extract custom claims
		// TODO: extract role from claims: need to modify oidc provider
		var claims struct {
			Sub  string `json:"sub"`
			Name string `json:"name"`
			// Exp               uint64 `json:"exp"`
			// Iat               uint64 `json:"iat"`
			// AuthTime          uint64 `json:"auth_time"`
			// Jti               string `json:"jti"`
			// Iss               string `json:"iss"`
			// Aud               string `json:"aud"`
			// Typ               string `json:"typ"`
			// Azp               string `json:"azp"`
			// SessionState      string `json:"session_state"`
			// AtHash            string `json:"at_hash"`
			// Acr               string `json:"acr"`
			// Sid               string `json:"sid"`
			// PreferredUsername string `json:"preferred_username"`
			// GivenName         string `json:"given_name"`
			// FamilyName        string `json:"family_name"`
		}
		if err := idToken.Claims(&claims); err != nil {
			fail(e.OAuthGetClaimsFailed.QueryString())
			return
		}

		// Both the subject and the display name are required to build a user.
		if claims.Name == "" || claims.Sub == "" {
			fail(e.UserInvalid.QueryString())
			return
		}

		// The subject claim is used as the user's UUID; reject anything that
		// does not parse as one.
		uid := pgtype.UUID{}
		if err := uid.Set(claims.Sub); err != nil {
			fail(e.UserInvalid.QueryString())
			return
		}

		// Check user existence
		u, status := h.user.ProfileOrCreate(&user.CreateData{UID: uid, NickName: claims.Name})
		if status != e.Success {
			fail(status.QueryString())
			return
		}

		// Increment user version
		version, status := h.user.IncrVersion(u.ID)
		if status != e.Success {
			fail(status.QueryString())
			return
		}

		// Sign JWT token
		claim := &model.Claim{
			UID:     u.ID,
			Role:    u.Role,
			Version: version,
		}
		jwt, status := h.jwt.SignClaim(claim)
		if status != e.Success {
			fail(status.QueryString())
			return
		}

		c.Redirect(http.StatusFound, "/login?redirect_token="+jwt)
	}
}