chore: move oauth into api

This commit is contained in:
Paul Pan 2024-01-05 00:57:43 +08:00
parent 310eff0e88
commit bb21f5858d
Signed by: Paul
GPG Key ID: D639BDF5BA578AF4
7 changed files with 86 additions and 128 deletions

View File

@ -17,7 +17,6 @@ import (
"git.0x7f.app/WOJ/woj-server/internal/service/user" "git.0x7f.app/WOJ/woj-server/internal/service/user"
"git.0x7f.app/WOJ/woj-server/internal/web/jwt" "git.0x7f.app/WOJ/woj-server/internal/web/jwt"
"git.0x7f.app/WOJ/woj-server/internal/web/metrics" "git.0x7f.app/WOJ/woj-server/internal/web/metrics"
"git.0x7f.app/WOJ/woj-server/internal/web/oauth"
"git.0x7f.app/WOJ/woj-server/internal/web/router" "git.0x7f.app/WOJ/woj-server/internal/web/router"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/samber/do" "github.com/samber/do"
@ -76,7 +75,6 @@ func prepareServices(c *cli.Context) *do.Injector {
{ // web helper services { // web helper services
do.Provide(injector, metrics.NewService) do.Provide(injector, metrics.NewService)
do.Provide(injector, jwt.NewService) do.Provide(injector, jwt.NewService)
do.Provide(injector, oauth.NewService)
do.Provide(injector, router.NewService) do.Provide(injector, router.NewService)
} }

View File

@ -16,7 +16,7 @@ import (
// @Tags oauth // @Tags oauth
// @Produce json // @Produce json
// @Router /oauth/callback [get] // @Router /oauth/callback [get]
func (s *service) CallbackHandler() gin.HandlerFunc { func (h *handler) CallbackHandler() gin.HandlerFunc {
// TODO: we are returning e.Response directly here, we should redirect to a trampoline page, passing the response as query string // TODO: we are returning e.Response directly here, we should redirect to a trampoline page, passing the response as query string
return func(c *gin.Context) { return func(c *gin.Context) {
@ -29,14 +29,14 @@ func (s *service) CallbackHandler() gin.HandlerFunc {
// Get state from redis // Get state from redis
key = fmt.Sprintf(oauthStateKey, key) key = fmt.Sprintf(oauthStateKey, key)
expected, err := s.cache.Get().Get(context.Background(), key).Result() expected, err := h.cache.Get().Get(context.Background(), key).Result()
if err != nil { if err != nil {
e.Pong[any](c, e.RedisError, nil) e.Pong[any](c, e.RedisError, nil)
return return
} }
// Whether state is valid, delete it // Whether state is valid, delete it
s.cache.Get().Unlink(context.Background(), key) h.cache.Get().Unlink(context.Background(), key)
c.SetCookie(oauthStateCookieName, "", -1, "/", "", false, true) c.SetCookie(oauthStateCookieName, "", -1, "/", "", false, true)
// Verify state // Verify state
@ -46,7 +46,7 @@ func (s *service) CallbackHandler() gin.HandlerFunc {
} }
// Exchange code for token // Exchange code for token
token, err := s.conf.Exchange(context.Background(), c.Query("code")) token, err := h.conf.Exchange(context.Background(), c.Query("code"))
if err != nil { if err != nil {
e.Pong[any](c, e.OAuthExchangeFailed, nil) e.Pong[any](c, e.OAuthExchangeFailed, nil)
return return
@ -60,7 +60,7 @@ func (s *service) CallbackHandler() gin.HandlerFunc {
} }
// Parse and verify ID Token payload. // Parse and verify ID Token payload.
idToken, err := s.verifier.Verify(context.Background(), raw) idToken, err := h.verifier.Verify(context.Background(), raw)
if err != nil { if err != nil {
e.Pong[any](c, e.OAuthVerifyFailed, nil) e.Pong[any](c, e.OAuthVerifyFailed, nil)
return return
@ -85,14 +85,14 @@ func (s *service) CallbackHandler() gin.HandlerFunc {
} }
// Check user existence // Check user existence
u, status := s.user.ProfileOrCreate(&user.CreateData{UserName: claims.Email, NickName: claims.Nickname}) u, status := h.user.ProfileOrCreate(&user.CreateData{UserName: claims.Email, NickName: claims.Nickname})
if status != e.Success { if status != e.Success {
e.Pong[any](c, status, nil) e.Pong[any](c, status, nil)
return return
} }
// Increment user version // Increment user version
version, status := s.user.IncrVersion(u.ID) version, status := h.user.IncrVersion(u.ID)
if status != e.Success { if status != e.Success {
e.Pong[any](c, status, nil) e.Pong[any](c, status, nil)
return return
@ -104,7 +104,7 @@ func (s *service) CallbackHandler() gin.HandlerFunc {
Role: u.Role, Role: u.Role,
Version: version, Version: version,
} }
jwt, status := s.jwt.SignClaim(claim) jwt, status := h.jwt.SignClaim(claim)
if status != e.Success { if status != e.Success {
e.Pong[any](c, status, nil) e.Pong[any](c, status, nil)
return return

View File

@ -0,0 +1,71 @@
package oauth
import (
"context"
"git.0x7f.app/WOJ/woj-server/internal/misc/config"
"git.0x7f.app/WOJ/woj-server/internal/misc/log"
"git.0x7f.app/WOJ/woj-server/internal/repo/cache"
"git.0x7f.app/WOJ/woj-server/internal/service/user"
"git.0x7f.app/WOJ/woj-server/internal/web/jwt"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin"
"github.com/samber/do"
"go.uber.org/zap"
"golang.org/x/oauth2"
"time"
)
type Handler interface {
LoginHandler() gin.HandlerFunc
CallbackHandler() gin.HandlerFunc
}
const (
oauthStateCookieName = "oauth_state"
oauthStateKey = "OAuthState:%s"
oauthStateLiveness = 15 * time.Minute
)
func RouteRegister(rg *gin.RouterGroup, i *do.Injector) {
conf := do.MustInvoke[config.Service](i).GetConfig()
if conf.WebServer.OAuth.Domain == "" {
return
}
app := &handler{}
app.log = do.MustInvoke[log.Service](i).GetLogger("oauth")
app.jwt = do.MustInvoke[jwt.Service](i)
app.user = do.MustInvoke[user.Service](i)
app.cache = do.MustInvoke[cache.Service](i)
var err error
app.provider, err = oidc.NewProvider(context.Background(), conf.WebServer.OAuth.Domain)
if err != nil {
app.log.Error("failed to create oauth provider", zap.Error(err), zap.String("domain", conf.WebServer.OAuth.Domain))
return
}
app.verifier = app.provider.Verifier(&oidc.Config{ClientID: conf.WebServer.OAuth.ClientID})
app.conf = oauth2.Config{
ClientID: conf.WebServer.OAuth.ClientID,
ClientSecret: conf.WebServer.OAuth.ClientSecret,
RedirectURL: conf.WebServer.PublicBase + rg.BasePath() + "/callback",
Endpoint: app.provider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "roles"},
}
rg.GET("/login", app.LoginHandler())
rg.GET("/callback", app.CallbackHandler())
}
type handler struct {
log *zap.Logger
jwt jwt.Service
user user.Service
cache cache.Service
provider *oidc.Provider
conf oauth2.Config
verifier *oidc.IDTokenVerifier
}

View File

@ -7,12 +7,6 @@ import (
"git.0x7f.app/WOJ/woj-server/pkg/utils" "git.0x7f.app/WOJ/woj-server/pkg/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"time"
)
const (
oauthStateCookieName = "oauth_state"
oauthStateKey = "OAuthState:%s"
) )
// LoginHandler // LoginHandler
@ -22,21 +16,21 @@ const (
// @Produce json // @Produce json
// @Response 200 {object} e.Response[string] "random string" // @Response 200 {object} e.Response[string] "random string"
// @Router /oauth/login [post] // @Router /oauth/login [post]
func (s *service) LoginHandler() gin.HandlerFunc { func (h *handler) LoginHandler() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
state := utils.RandomString(64) state := utils.RandomString(64)
key := utils.RandomString(16) key := utils.RandomString(16)
err := s.cache.Get().Set(context.Background(), fmt.Sprintf(oauthStateKey, key), state, 15*time.Minute).Err() err := h.cache.Get().Set(context.Background(), fmt.Sprintf(oauthStateKey, key), state, oauthStateLiveness).Err()
if err != nil { if err != nil {
e.Pong[any](c, e.RedisError, nil) e.Pong[any](c, e.RedisError, nil)
return return
} }
c.SetSameSite(http.SameSiteStrictMode) c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(oauthStateCookieName, key, 15*60, "/", "", false, true) c.SetCookie(oauthStateCookieName, key, int(oauthStateLiveness.Seconds()), "/", "", false, true)
url := s.conf.AuthCodeURL(state) url := h.conf.AuthCodeURL(state)
e.Pong(c, e.Success, url) e.Pong(c, e.Success, url)
} }
} }

View File

@ -1,95 +0,0 @@
package oauth
import (
"context"
"git.0x7f.app/WOJ/woj-server/internal/misc/config"
"git.0x7f.app/WOJ/woj-server/internal/misc/log"
"git.0x7f.app/WOJ/woj-server/internal/repo/cache"
"git.0x7f.app/WOJ/woj-server/internal/service/user"
"git.0x7f.app/WOJ/woj-server/internal/web/jwt"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin"
"github.com/samber/do"
"go.uber.org/zap"
"golang.org/x/oauth2"
)
type Service interface {
LoginHandler() gin.HandlerFunc
CallbackHandler() gin.HandlerFunc
IsEnabled() bool
GetLoginPath() string
GetCallbackPath() string
HealthCheck() error
}
const (
basePath = "/oauth"
callbackPath = basePath + "/callback"
loginPath = basePath + "/login"
)
func NewService(i *do.Injector) (Service, error) {
srv := &service{}
srv.log = do.MustInvoke[log.Service](i).GetLogger("oauth")
srv.jwt = do.MustInvoke[jwt.Service](i)
srv.user = do.MustInvoke[user.Service](i)
srv.cache = do.MustInvoke[cache.Service](i)
srv.enabled = false
conf := do.MustInvoke[config.Service](i).GetConfig()
if conf.WebServer.OAuth.Domain == "" {
return srv, srv.err
}
srv.provider, srv.err = oidc.NewProvider(context.Background(), conf.WebServer.OAuth.Domain)
if srv.err != nil {
srv.log.Error("failed to create oauth provider", zap.Error(srv.err), zap.String("domain", conf.WebServer.OAuth.Domain))
return srv, srv.err
}
srv.verifier = srv.provider.Verifier(&oidc.Config{ClientID: conf.WebServer.OAuth.ClientID})
srv.conf = oauth2.Config{
ClientID: conf.WebServer.OAuth.ClientID,
ClientSecret: conf.WebServer.OAuth.ClientSecret,
RedirectURL: conf.WebServer.PublicBase + callbackPath,
Endpoint: srv.provider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "roles"},
}
srv.enabled = true
return srv, srv.err
}
type service struct {
log *zap.Logger
jwt jwt.Service
user user.Service
cache cache.Service
provider *oidc.Provider
conf oauth2.Config
verifier *oidc.IDTokenVerifier
enabled bool
err error
}
func (s *service) IsEnabled() bool {
return s.enabled && s.err == nil
}
func (s *service) GetLoginPath() string {
return loginPath
}
func (s *service) GetCallbackPath() string {
return callbackPath
}
func (s *service) HealthCheck() error {
return s.err
}

View File

@ -2,6 +2,7 @@ package router
import ( import (
"git.0x7f.app/WOJ/woj-server/internal/api/debug" "git.0x7f.app/WOJ/woj-server/internal/api/debug"
"git.0x7f.app/WOJ/woj-server/internal/api/oauth"
"git.0x7f.app/WOJ/woj-server/internal/api/problem" "git.0x7f.app/WOJ/woj-server/internal/api/problem"
"git.0x7f.app/WOJ/woj-server/internal/api/status" "git.0x7f.app/WOJ/woj-server/internal/api/status"
"git.0x7f.app/WOJ/woj-server/internal/api/submission" "git.0x7f.app/WOJ/woj-server/internal/api/submission"
@ -30,4 +31,5 @@ var endpoints = []model.EndpointInfo{
{Version: "/v1", Path: "/problem", Register: problem.RouteRegister}, {Version: "/v1", Path: "/problem", Register: problem.RouteRegister},
{Version: "/v1", Path: "/submission", Register: submission.RouteRegister}, {Version: "/v1", Path: "/submission", Register: submission.RouteRegister},
{Version: "/v1", Path: "/status", Register: status.RouteRegister}, {Version: "/v1", Path: "/status", Register: status.RouteRegister},
{Version: "/v1", Path: "/oauth", Register: oauth.RouteRegister},
} }

View File

@ -5,7 +5,6 @@ import (
"git.0x7f.app/WOJ/woj-server/internal/misc/log" "git.0x7f.app/WOJ/woj-server/internal/misc/log"
"git.0x7f.app/WOJ/woj-server/internal/model" "git.0x7f.app/WOJ/woj-server/internal/model"
"git.0x7f.app/WOJ/woj-server/internal/web/metrics" "git.0x7f.app/WOJ/woj-server/internal/web/metrics"
"git.0x7f.app/WOJ/woj-server/internal/web/oauth"
_ "git.0x7f.app/WOJ/woj-server/internal/web/router/docs" _ "git.0x7f.app/WOJ/woj-server/internal/web/router/docs"
"git.0x7f.app/WOJ/woj-server/pkg/utils" "git.0x7f.app/WOJ/woj-server/pkg/utils"
sentrygin "github.com/getsentry/sentry-go/gin" sentrygin "github.com/getsentry/sentry-go/gin"
@ -32,7 +31,6 @@ type Service interface {
func NewService(i *do.Injector) (Service, error) { func NewService(i *do.Injector) (Service, error) {
srv := &service{} srv := &service{}
srv.metric = do.MustInvoke[metrics.Service](i) srv.metric = do.MustInvoke[metrics.Service](i)
srv.oauth = do.MustInvoke[oauth.Service](i)
srv.logger = do.MustInvoke[log.Service](i) srv.logger = do.MustInvoke[log.Service](i)
conf := do.MustInvoke[config.Service](i).GetConfig() conf := do.MustInvoke[config.Service](i).GetConfig()
@ -44,12 +42,8 @@ func NewService(i *do.Injector) (Service, error) {
type service struct { type service struct {
logger log.Service logger log.Service
engine *gin.Engine engine *gin.Engine
// middlewares
metric metrics.Service metric metrics.Service
oauth oauth.Service err error
err error
} }
func (s *service) GetRouter() *gin.Engine { func (s *service) GetRouter() *gin.Engine {
@ -140,12 +134,6 @@ func (s *service) initRouters(conf *model.Config, injector *do.Injector) *gin.En
api := r.Group("/api/") api := r.Group("/api/")
s.setupApi(api, injector) s.setupApi(api, injector)
// oauth2
if s.oauth.IsEnabled() {
r.POST(s.oauth.GetLoginPath(), s.oauth.LoginHandler())
r.GET(s.oauth.GetCallbackPath(), s.oauth.CallbackHandler())
}
// static files // static files
r.Use(static.Serve("/", static.LocalFile("./resource/frontend", true))) r.Use(static.Serve("/", static.LocalFile("./resource/frontend", true)))