159 lines
3.9 KiB
Go
159 lines
3.9 KiB
Go
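// Package middleware implements in-memory user sessions for HTTP handlers:
// OAuth login, callback and logout views, and a middleware that puts the
// current user and session ID on the request context.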
|
|
package middleware
|
|
|
|
import (
	"context"
	"crypto/rand"
	"log"
	"net/http"
	"sync"
	"time"

	urls "forgejo.gwairfelin.com/max/gispatcho"
	"forgejo.gwairfelin.com/max/gonotes/internal/auth"
	"golang.org/x/oauth2"
)
|
|
|
|
type Session struct {
|
|
User string
|
|
}
|
|
|
|
type SessionStore struct {
|
|
sessions map[string]Session
|
|
oauth *oauth2.Config
|
|
Routes urls.URLs
|
|
}
|
|
|
|
type ContextKey string
|
|
|
|
func NewSessionStore(oauth *oauth2.Config, prefix string) SessionStore {
|
|
store := SessionStore{
|
|
sessions: make(map[string]Session, 10),
|
|
oauth: oauth,
|
|
}
|
|
store.Routes = urls.URLs{
|
|
Prefix: prefix,
|
|
URLs: map[string]urls.URL{
|
|
"login": {Path: "/login/", Protocol: "GET", Handler: store.LoginViewOAUTH},
|
|
"callback": {Path: "/callback/", Protocol: "GET", Handler: store.CallbackViewOAUTH},
|
|
"logout": {Path: "/logout/", Protocol: "GET", Handler: store.LogoutView},
|
|
},
|
|
}
|
|
return store
|
|
}
|
|
|
|
// Log a user in
|
|
func (s *SessionStore) Login(user string, w http.ResponseWriter) string {
|
|
sessionID := rand.Text()
|
|
s.sessions[sessionID] = Session{User: user}
|
|
|
|
cookie := http.Cookie{
|
|
Name: "id", Value: sessionID, MaxAge: 3600,
|
|
Secure: true, HttpOnly: true, Path: "/",
|
|
}
|
|
|
|
http.SetCookie(w, &cookie)
|
|
return sessionID
|
|
}
|
|
|
|
// View to logout a user
|
|
func (s *SessionStore) LogoutView(w http.ResponseWriter, r *http.Request) {
|
|
session := r.Context().Value(ContextKey("session")).(string)
|
|
|
|
delete(s.sessions, session)
|
|
cookie := http.Cookie{
|
|
Name: "id", Value: "", MaxAge: -1,
|
|
Secure: true, HttpOnly: true, Path: "/",
|
|
}
|
|
|
|
http.SetCookie(w, &cookie)
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
}
|
|
|
|
// View to log in a user via oauth
|
|
func (s *SessionStore) LoginViewOAUTH(w http.ResponseWriter, r *http.Request) {
|
|
log.Printf("%+v", *s.oauth)
|
|
|
|
oauthState := auth.GenerateStateOAUTHCookie(w, s.Routes.Prefix)
|
|
|
|
url := s.oauth.AuthCodeURL(oauthState)
|
|
log.Printf("Redirecting to %s", url)
|
|
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
|
}
|
|
|
|
// Oauth callback view
|
|
func (s *SessionStore) CallbackViewOAUTH(w http.ResponseWriter, r *http.Request) {
|
|
// Read oauthState from Cookie
|
|
oauthState, err := r.Cookie("oauthstate")
|
|
if err != nil {
|
|
log.Printf("An error occured during login: %s", err)
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
|
|
log.Printf("%v", oauthState)
|
|
|
|
if r.FormValue("state") != oauthState.Value {
|
|
log.Println("invalid oauth state")
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
|
|
data, err := auth.GetUserData(r.FormValue("code"), s.oauth)
|
|
if err != nil {
|
|
log.Println(err.Error())
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
|
|
username, ok := data["preferred_username"]
|
|
if !ok {
|
|
log.Println("No username in auth response")
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
userStr, ok := username.(string)
|
|
if !ok {
|
|
log.Println("Username not interpretable as string")
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
|
|
s.Login(userStr, w)
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
}
|
|
|
|
// Turn the session store into a middleware.
|
|
// Sets the user on the context based on the available session cookie
|
|
func (s *SessionStore) AsMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
sessionCookie, err := r.Cookie("id")
|
|
user := "anon"
|
|
var cookieVal string
|
|
// Session exists
|
|
if err == nil {
|
|
session, ok := s.sessions[sessionCookie.Value]
|
|
|
|
// Session not expired
|
|
if ok {
|
|
user = session.User
|
|
cookieVal = sessionCookie.Value
|
|
}
|
|
}
|
|
|
|
nextWithSessionContext(w, r, next, user, cookieVal)
|
|
})
|
|
}
|
|
|
|
func nextWithSessionContext(w http.ResponseWriter, r *http.Request, next http.Handler, user string, sessionID string) {
|
|
ctx := r.Context()
|
|
ctx = context.WithValue(
|
|
context.WithValue(
|
|
ctx,
|
|
ContextKey("user"),
|
|
user,
|
|
),
|
|
ContextKey("session"),
|
|
sessionID,
|
|
)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|