Refactor oauth login
This commit is contained in:
parent
a750f646a9
commit
d30327817e
4 changed files with 90 additions and 90 deletions
|
|
@ -8,23 +8,30 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forgejo.gwairfelin.com/max/gonotes/internal/auth"
|
|
||||||
"forgejo.gwairfelin.com/max/gonotes/internal/conf"
|
"forgejo.gwairfelin.com/max/gonotes/internal/conf"
|
||||||
"forgejo.gwairfelin.com/max/gonotes/internal/middleware"
|
"forgejo.gwairfelin.com/max/gonotes/internal/middleware"
|
||||||
"forgejo.gwairfelin.com/max/gonotes/internal/notes"
|
"forgejo.gwairfelin.com/max/gonotes/internal/notes"
|
||||||
"forgejo.gwairfelin.com/max/gonotes/internal/notes/views"
|
"forgejo.gwairfelin.com/max/gonotes/internal/notes/views"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var confFile string
|
var confFile string
|
||||||
|
|
||||||
sessions := middleware.NewSessionStore()
|
|
||||||
|
|
||||||
flag.StringVar(&confFile, "c", "/etc/gonotes/conf.toml", "Specify path to config file.")
|
flag.StringVar(&confFile, "c", "/etc/gonotes/conf.toml", "Specify path to config file.")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
conf.LoadConfig(confFile)
|
conf.LoadConfig(confFile)
|
||||||
|
|
||||||
|
oauth := &oauth2.Config{
|
||||||
|
ClientID: conf.Conf.OIDC.ClientID,
|
||||||
|
ClientSecret: conf.Conf.OIDC.ClientSecret,
|
||||||
|
Endpoint: oauth2.Endpoint{AuthURL: conf.Conf.OIDC.AuthURL, TokenURL: conf.Conf.OIDC.TokenURL},
|
||||||
|
RedirectURL: conf.Conf.OIDC.RedirectURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions := middleware.NewSessionStore(oauth, "/auth")
|
||||||
|
|
||||||
err := notes.Init()
|
err := notes.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
|
@ -34,7 +41,7 @@ func main() {
|
||||||
|
|
||||||
router := http.NewServeMux()
|
router := http.NewServeMux()
|
||||||
notesRouter := views.GetRoutes("/notes")
|
notesRouter := views.GetRoutes("/notes")
|
||||||
authRouter := auth.GetRoutes("/auth", sessions.Login)
|
sessionRouter := sessions.Routes.GetRouter()
|
||||||
|
|
||||||
cacheExpiration, err := time.ParseDuration("24h")
|
cacheExpiration, err := time.ParseDuration("24h")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -43,10 +50,9 @@ func main() {
|
||||||
|
|
||||||
etag := middleware.NewETag("static", cacheExpiration)
|
etag := middleware.NewETag("static", cacheExpiration)
|
||||||
|
|
||||||
router.Handle("/logout/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.HandlerFunc(sessions.Logout))))
|
|
||||||
router.Handle("/", middleware.LoggingMiddleware(http.RedirectHandler("/notes/", http.StatusFound)))
|
router.Handle("/", middleware.LoggingMiddleware(http.RedirectHandler("/notes/", http.StatusFound)))
|
||||||
router.Handle("/notes/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/notes", notesRouter))))
|
router.Handle("/notes/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/notes", notesRouter))))
|
||||||
router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", authRouter))))
|
router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", sessionRouter))))
|
||||||
router.Handle(
|
router.Handle(
|
||||||
"/static/",
|
"/static/",
|
||||||
middleware.LoggingMiddleware(
|
middleware.LoggingMiddleware(
|
||||||
|
|
|
||||||
|
|
@ -9,108 +9,32 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
urls "forgejo.gwairfelin.com/max/gispatcho"
|
|
||||||
"forgejo.gwairfelin.com/max/gonotes/internal/conf"
|
"forgejo.gwairfelin.com/max/gonotes/internal/conf"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var myurls urls.URLs
|
|
||||||
var oauthConfig *oauth2.Config
|
|
||||||
var loginFunction func(user string, w http.ResponseWriter) string
|
|
||||||
|
|
||||||
type userInfo struct {
|
type userInfo struct {
|
||||||
preferred_username string
|
preferred_username string
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRoutes(prefix string, _loginFunction func(user string, w http.ResponseWriter) string) *http.ServeMux {
|
func GenerateStateOAUTHCookie(w http.ResponseWriter, prefix string) string {
|
||||||
loginFunction = _loginFunction
|
|
||||||
|
|
||||||
oauthConfig = &oauth2.Config{
|
|
||||||
ClientID: conf.Conf.OIDC.ClientID,
|
|
||||||
ClientSecret: conf.Conf.OIDC.ClientSecret,
|
|
||||||
Endpoint: oauth2.Endpoint{AuthURL: conf.Conf.OIDC.AuthURL, TokenURL: conf.Conf.OIDC.TokenURL},
|
|
||||||
RedirectURL: conf.Conf.OIDC.RedirectURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
myurls = urls.URLs{
|
|
||||||
Prefix: prefix,
|
|
||||||
URLs: map[string]urls.URL{
|
|
||||||
"login": {Path: "/oauth/login/", Protocol: "GET", Handler: oauthLogin},
|
|
||||||
"callback": {Path: "/oauth/callback/", Protocol: "GET", Handler: oauthCallback},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return myurls.GetRouter()
|
|
||||||
}
|
|
||||||
|
|
||||||
func oauthLogin(w http.ResponseWriter, r *http.Request) {
|
|
||||||
log.Printf("%+v", *oauthConfig)
|
|
||||||
|
|
||||||
oauthState := generateStateOAUTHCookie(w)
|
|
||||||
|
|
||||||
url := oauthConfig.AuthCodeURL(oauthState)
|
|
||||||
log.Printf("Redirecting to %s", url)
|
|
||||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateStateOAUTHCookie(w http.ResponseWriter) string {
|
|
||||||
|
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
rand.Read(b)
|
rand.Read(b)
|
||||||
state := base64.URLEncoding.EncodeToString(b)
|
state := base64.URLEncoding.EncodeToString(b)
|
||||||
cookie := http.Cookie{
|
cookie := http.Cookie{
|
||||||
Name: "oauthstate", Value: state,
|
Name: "oauthstate", Value: state,
|
||||||
MaxAge: 30, Secure: true, HttpOnly: true, Path: "/auth/oauth/",
|
MaxAge: 30, Secure: true, HttpOnly: true, Path: prefix,
|
||||||
}
|
}
|
||||||
http.SetCookie(w, &cookie)
|
http.SetCookie(w, &cookie)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
func oauthCallback(w http.ResponseWriter, r *http.Request) {
|
func GetUserData(code string, oauth *oauth2.Config) (map[string]any, error) {
|
||||||
// Read oauthState from Cookie
|
|
||||||
oauthState, err := r.Cookie("oauthstate")
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("An error occured during login: %s", err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("%v", oauthState)
|
|
||||||
|
|
||||||
if r.FormValue("state") != oauthState.Value {
|
|
||||||
log.Println("invalid oauth state")
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := getUserData(r.FormValue("code"))
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err.Error())
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
username, ok := data["preferred_username"]
|
|
||||||
if !ok {
|
|
||||||
log.Println("No username in auth response")
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
userStr, ok := username.(string)
|
|
||||||
if !ok {
|
|
||||||
log.Println("Username not interpretable as string")
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
loginFunction(userStr, w)
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUserData(code string) (map[string]any, error) {
|
|
||||||
// Use code to get token and get user info from Google.
|
// Use code to get token and get user info from Google.
|
||||||
|
|
||||||
token, err := oauthConfig.Exchange(context.Background(), code)
|
token, err := oauth.Exchange(context.Background(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("code exchange wrong: %s", err.Error())
|
return nil, fmt.Errorf("code exchange wrong: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -64,11 +64,11 @@
|
||||||
{{template "navLinks" .}}
|
{{template "navLinks" .}}
|
||||||
{{if eq .user "anon"}}
|
{{if eq .user "anon"}}
|
||||||
<li>
|
<li>
|
||||||
<a class="nav-link" href="/auth/oauth/login/">Login</a>
|
<a class="nav-link" href="/auth/login/">Login</a>
|
||||||
</li>
|
</li>
|
||||||
{{else}}
|
{{else}}
|
||||||
<li>
|
<li>
|
||||||
<a class="nav-link" href="/logout/">Logout {{.user}}</a>
|
<a class="nav-link" href="/auth/logout/">Logout {{.user}}</a>
|
||||||
</li>
|
</li>
|
||||||
{{end}}
|
{{end}}
|
||||||
</ul>
|
</ul>
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,12 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
urls "forgejo.gwairfelin.com/max/gispatcho"
|
||||||
|
"forgejo.gwairfelin.com/max/gonotes/internal/auth"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
|
|
@ -13,12 +18,26 @@ type Session struct {
|
||||||
|
|
||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
sessions map[string]Session
|
sessions map[string]Session
|
||||||
|
oauth *oauth2.Config
|
||||||
|
Routes urls.URLs
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextKey string
|
type ContextKey string
|
||||||
|
|
||||||
func NewSessionStore() SessionStore {
|
func NewSessionStore(oauth *oauth2.Config, prefix string) SessionStore {
|
||||||
return SessionStore{sessions: make(map[string]Session, 10)}
|
store := SessionStore{
|
||||||
|
sessions: make(map[string]Session, 10),
|
||||||
|
oauth: oauth,
|
||||||
|
}
|
||||||
|
store.Routes = urls.URLs{
|
||||||
|
Prefix: prefix,
|
||||||
|
URLs: map[string]urls.URL{
|
||||||
|
"login": {Path: "/login/", Protocol: "GET", Handler: store.LoginViewOAUTH},
|
||||||
|
"callback": {Path: "/callback/", Protocol: "GET", Handler: store.CallbackViewOAUTH},
|
||||||
|
"logout": {Path: "/logout/", Protocol: "GET", Handler: store.Logout},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionStore) Login(user string, w http.ResponseWriter) string {
|
func (s *SessionStore) Login(user string, w http.ResponseWriter) string {
|
||||||
|
|
@ -47,6 +66,57 @@ func (s *SessionStore) Logout(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) LoginViewOAUTH(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Printf("%+v", *s.oauth)
|
||||||
|
|
||||||
|
oauthState := auth.GenerateStateOAUTHCookie(w, s.Routes.Prefix)
|
||||||
|
|
||||||
|
url := s.oauth.AuthCodeURL(oauthState)
|
||||||
|
log.Printf("Redirecting to %s", url)
|
||||||
|
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) CallbackViewOAUTH(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Read oauthState from Cookie
|
||||||
|
oauthState, err := r.Cookie("oauthstate")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("An error occured during login: %s", err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("%v", oauthState)
|
||||||
|
|
||||||
|
if r.FormValue("state") != oauthState.Value {
|
||||||
|
log.Println("invalid oauth state")
|
||||||
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := auth.GetUserData(r.FormValue("code"), s.oauth)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err.Error())
|
||||||
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
username, ok := data["preferred_username"]
|
||||||
|
if !ok {
|
||||||
|
log.Println("No username in auth response")
|
||||||
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userStr, ok := username.(string)
|
||||||
|
if !ok {
|
||||||
|
log.Println("Username not interpretable as string")
|
||||||
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Login(userStr, w)
|
||||||
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SessionStore) AsMiddleware(next http.Handler) http.Handler {
|
func (s *SessionStore) AsMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
sessionCookie, err := r.Cookie("id")
|
sessionCookie, err := r.Cookie("id")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue