woj-server/internal/api/oauth/callback.go

139 lines
4.0 KiB
Go
Raw Permalink Normal View History

2024-01-03 00:55:41 +08:00
package oauth
import (
"context"
"fmt"
2024-01-03 00:55:41 +08:00
"git.0x7f.app/WOJ/woj-server/internal/e"
"git.0x7f.app/WOJ/woj-server/internal/model"
"git.0x7f.app/WOJ/woj-server/internal/service/user"
"github.com/gin-gonic/gin"
2024-04-28 22:58:39 +08:00
"github.com/jackc/pgtype"
"net/http"
2024-01-03 00:55:41 +08:00
)
// CallbackHandler
// @Summary Callback with OAuth2
// @Description Callback endpoint from OAuth2
// @Tags oauth
// @Produce json
2024-04-28 22:58:39 +08:00
// @Router /v1/oauth/callback [get]
2024-01-05 00:57:43 +08:00
func (h *handler) CallbackHandler() gin.HandlerFunc {
2024-04-28 22:58:39 +08:00
// TODO: Figure out a better way to cooperate with frontend
// Currently using /login?redirect_token=xxx to pass jwt token
// /error?message=xxx to pass error message
2024-01-03 00:55:41 +08:00
return func(c *gin.Context) {
// Extract key from cookie
key, err := c.Cookie(oauthStateCookieName)
2024-01-03 00:55:41 +08:00
if err != nil {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+e.InvalidParameter.QueryString())
2024-01-03 00:55:41 +08:00
return
}
// Get state from redis
key = fmt.Sprintf(oauthStateKey, key)
2024-01-05 00:57:43 +08:00
expected, err := h.cache.Get().Get(context.Background(), key).Result()
if err != nil {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+e.RedisError.QueryString())
return
}
// Whether state is valid, delete it
2024-01-05 00:57:43 +08:00
h.cache.Get().Unlink(context.Background(), key)
// Verify state
if c.Query("state") != expected {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+e.OAuthStateMismatch.QueryString())
2024-01-03 00:55:41 +08:00
return
}
// Exchange code for token
2024-01-05 00:57:43 +08:00
token, err := h.conf.Exchange(context.Background(), c.Query("code"))
2024-01-03 00:55:41 +08:00
if err != nil {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+e.OAuthExchangeFailed.QueryString())
2024-01-03 00:55:41 +08:00
return
}
// Extract the ID Token from OAuth2 token.
raw, ok := token.Extra("id_token").(string)
if !ok {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+e.OAuthExchangeFailed.QueryString())
2024-01-03 00:55:41 +08:00
return
}
// Parse and verify ID Token payload.
2024-01-05 00:57:43 +08:00
idToken, err := h.verifier.Verify(context.Background(), raw)
2024-01-03 00:55:41 +08:00
if err != nil {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+e.OAuthVerifyFailed.QueryString())
2024-01-03 00:55:41 +08:00
return
}
// Extract custom claims
2024-04-28 22:58:39 +08:00
// TODO: extract role from claims: need to modify oidc provider
2024-01-03 00:55:41 +08:00
var claims struct {
2024-04-28 22:58:39 +08:00
Sub string `json:"sub"`
Name string `json:"name"`
// Exp uint64 `json:"exp"`
// Iat uint64 `json:"iat"`
// AuthTime uint64 `json:"auth_time"`
// Jti string `json:"jti"`
// Iss string `json:"iss"`
// Aud string `json:"aud"`
// Typ string `json:"typ"`
// Azp string `json:"azp"`
// SessionState string `json:"session_state"`
// AtHash string `json:"at_hash"`
// Acr string `json:"acr"`
// Sid string `json:"sid"`
// PreferredUsername string `json:"preferred_username"`
// GivenName string `json:"given_name"`
// FamilyName string `json:"family_name"`
2024-01-03 00:55:41 +08:00
}
2024-04-28 22:58:39 +08:00
2024-01-03 00:55:41 +08:00
if err := idToken.Claims(&claims); err != nil {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+e.OAuthGetClaimsFailed.QueryString())
2024-01-03 00:55:41 +08:00
return
}
2024-04-28 22:58:39 +08:00
if claims.Name == "" || claims.Sub == "" {
c.Redirect(http.StatusFound, "/error?message="+e.UserInvalid.QueryString())
return
}
uid := pgtype.UUID{}
if err := uid.Set(claims.Sub); err != nil {
c.Redirect(http.StatusFound, "/error?message="+e.UserInvalid.QueryString())
2024-01-03 00:55:41 +08:00
return
}
// Check user existence
2024-04-28 22:58:39 +08:00
u, status := h.user.ProfileOrCreate(&user.CreateData{UID: uid, NickName: claims.Name})
2024-01-03 00:55:41 +08:00
if status != e.Success {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+status.QueryString())
2024-01-03 00:55:41 +08:00
return
}
// Increment user version
2024-01-05 00:57:43 +08:00
version, status := h.user.IncrVersion(u.ID)
2024-01-03 00:55:41 +08:00
if status != e.Success {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+status.QueryString())
2024-01-03 00:55:41 +08:00
return
}
// Sign JWT token
2024-01-03 00:55:41 +08:00
claim := &model.Claim{
UID: u.ID,
Role: u.Role,
Version: version,
}
2024-01-05 00:57:43 +08:00
jwt, status := h.jwt.SignClaim(claim)
2024-01-03 00:55:41 +08:00
if status != e.Success {
2024-04-28 22:58:39 +08:00
c.Redirect(http.StatusFound, "/error?message="+status.QueryString())
2024-01-03 00:55:41 +08:00
return
}
c.Redirect(http.StatusFound, "/login?redirect_token="+jwt)
2024-01-03 00:55:41 +08:00
}
}