diff --git a/internal/web/oauth/callback.go b/internal/web/oauth/callback.go index 3c217b2..baf21bd 100644 --- a/internal/web/oauth/callback.go +++ b/internal/web/oauth/callback.go @@ -2,11 +2,11 @@ package oauth import ( "context" + "fmt" userApi "git.0x7f.app/WOJ/woj-server/internal/api/user" "git.0x7f.app/WOJ/woj-server/internal/e" "git.0x7f.app/WOJ/woj-server/internal/model" "git.0x7f.app/WOJ/woj-server/internal/service/user" - "git.0x7f.app/WOJ/woj-server/pkg/utils" "github.com/gin-gonic/gin" ) @@ -20,15 +20,27 @@ func (s *service) CallbackHandler() gin.HandlerFunc { // TODO: we are returning e.Response directly here, we should redirect to a trampoline page, passing the response as query string return func(c *gin.Context) { - // verify state - signed, err := c.Cookie(oauthStateCookieName) + // Extract key from cookie + key, err := c.Cookie(oauthStateCookieName) if err != nil { e.Pong[any](c, e.InvalidParameter, nil) return } - state := c.Query("state") - if !utils.SignAndCompare(state, signed, []byte(s.conf.ClientSecret)) { + // Get state from redis + key = fmt.Sprintf(oauthStateKey, key) + expected, err := s.cache.Get().Get(context.Background(), key).Result() + if err != nil { + e.Pong[any](c, e.RedisError, nil) + return + } + + // Whether state is valid, delete it + s.cache.Get().Unlink(context.Background(), key) + c.SetCookie(oauthStateCookieName, "", -1, "/", "", false, true) + + // Verify state + if c.Query("state") != expected { e.Pong[any](c, e.OAuthStateMismatch, nil) return } @@ -72,21 +84,21 @@ func (s *service) CallbackHandler() gin.HandlerFunc { return } - // Check User Existence + // Check user existence u, status := s.user.ProfileOrCreate(&user.CreateData{UserName: claims.Email, NickName: claims.Nickname}) if status != e.Success { e.Pong[any](c, status, nil) return } - // Increment User Version + // Increment user version version, status := s.user.IncrVersion(u.ID) if status != e.Success { e.Pong[any](c, status, nil) return } - // Sign JWT Token + // Sign JWT token claim := &model.Claim{ UID: u.ID, Role: u.Role, diff --git a/internal/web/oauth/login.go b/internal/web/oauth/login.go index 44bfeed..4eb9917 100644 --- a/internal/web/oauth/login.go +++ b/internal/web/oauth/login.go @@ -1,14 +1,18 @@ package oauth import ( + "context" + "fmt" "git.0x7f.app/WOJ/woj-server/internal/e" "git.0x7f.app/WOJ/woj-server/pkg/utils" "github.com/gin-gonic/gin" "net/http" + "time" ) const ( oauthStateCookieName = "oauth_state" + oauthStateKey = "OAuthState:%s" ) // LoginHandler @@ -21,12 +25,18 @@ const ( func (s *service) LoginHandler() gin.HandlerFunc { return func(c *gin.Context) { state := utils.RandomString(64) - signed := utils.SignString(state, []byte(s.conf.ClientSecret)) - url := s.conf.AuthCodeURL(state) + key := utils.RandomString(16) + + err := s.cache.Get().Set(context.Background(), fmt.Sprintf(oauthStateKey, key), state, 15*time.Minute).Err() + if err != nil { + e.Pong[any](c, e.RedisError, nil) + return + } c.SetSameSite(http.SameSiteStrictMode) - c.SetCookie(oauthStateCookieName, signed, 15*60, "/", "", false, true) + c.SetCookie(oauthStateCookieName, key, 15*60, "/", "", false, true) + url := s.conf.AuthCodeURL(state) e.Pong(c, e.Success, url) } } diff --git a/internal/web/oauth/service.go b/internal/web/oauth/service.go index b83f142..48e2d90 100644 --- a/internal/web/oauth/service.go +++ b/internal/web/oauth/service.go @@ -4,6 +4,7 @@ import ( "context" "git.0x7f.app/WOJ/woj-server/internal/misc/config" "git.0x7f.app/WOJ/woj-server/internal/misc/log" + "git.0x7f.app/WOJ/woj-server/internal/repo/cache" "git.0x7f.app/WOJ/woj-server/internal/service/user" "git.0x7f.app/WOJ/woj-server/internal/web/jwt" "github.com/coreos/go-oidc/v3/oidc" @@ -34,6 +35,7 @@ func NewService(i *do.Injector) (Service, error) { srv.log = do.MustInvoke[log.Service](i).GetLogger("oauth") srv.jwt = do.MustInvoke[jwt.Service](i) srv.user = do.MustInvoke[user.Service](i) + srv.cache = do.MustInvoke[cache.Service](i) srv.enabled = false conf := do.MustInvoke[config.Service](i).GetConfig() @@ -63,9 +65,10 @@ func NewService(i *do.Injector) (Service, error) { } type service struct { - log *zap.Logger - jwt jwt.Service - user user.Service + log *zap.Logger + jwt jwt.Service + user user.Service + cache cache.Service provider *oidc.Provider conf oauth2.Config