// gonotes/internal/middleware/session.go

// Package middleware provides HTTP middleware for cookie-based user sessions.
package middleware
import (
	"context"
	"crypto/rand"
	"net/http"
	"sync"
	"time"
)
// Session is the server-side state attached to a session ID.
type Session struct {
	// User is the user name the session was opened for.
	User string
}

// sessionTTL bounds the lifetime of a session. It is used both as the cookie
// Max-Age and as the server-side expiry, so a copied or replayed cookie value
// stops working once the browser would have discarded it.
const sessionTTL = time.Hour

// sessionEntry is a stored session together with its server-side expiry.
type sessionEntry struct {
	session Session
	expires time.Time
}

// sessionTable is the mutable state behind a SessionStore. It lives behind a
// pointer so that copies of a SessionStore (NewSessionStore returns one by
// value) share a single map and a single lock.
type sessionTable struct {
	mu        sync.Mutex
	sessions  map[string]sessionEntry
	lastSweep time.Time // when expired entries were last purged
}

// SessionStore is an in-memory registry of sessions keyed by the random ID
// stored in the "id" cookie.
//
// Create stores with NewSessionStore; the zero value is not usable. All
// methods are safe for concurrent use (net/http serves requests on many
// goroutines), and copies of a SessionStore share the same sessions. Sessions
// are held only in memory.
type SessionStore struct {
	table *sessionTable
}

// ContextKey is the type of the request-context keys set by AsMiddleware:
// ContextKey("user") holds the user name and ContextKey("session") the
// session ID, both as strings.
type ContextKey string

// NewSessionStore returns an empty, ready-to-use SessionStore.
func NewSessionStore() SessionStore {
	return SessionStore{table: &sessionTable{
		sessions:  make(map[string]sessionEntry, 10),
		lastSweep: time.Now(),
	}}
}

// Login opens a new session for user, sets its ID as the "id" cookie on w and
// returns the new session ID.
func (s *SessionStore) Login(user string, w http.ResponseWriter) string {
	sessionID := rand.Text() // at least 128 random bits, base32-encoded
	now := time.Now()

	t := s.table
	t.mu.Lock()
	t.sessions[sessionID] = sessionEntry{
		session: Session{User: user},
		expires: now.Add(sessionTTL),
	}
	// Purge abandoned sessions at most once per TTL: this bounds memory
	// without making every login walk the whole map.
	if now.Sub(t.lastSweep) >= sessionTTL {
		for id, e := range t.sessions {
			if now.After(e.expires) {
				delete(t.sessions, id)
			}
		}
		t.lastSweep = now
	}
	t.mu.Unlock()

	http.SetCookie(w, sessionCookie(sessionID, int(sessionTTL/time.Second)))
	return sessionID
}

// Logout ends the caller's session and tells the browser to drop the "id"
// cookie.
//
// The session ID is taken from the request context set by AsMiddleware, or
// from the "id" cookie when Logout is mounted without the middleware. The
// cookie is cleared even when no session can be identified.
func (s *SessionStore) Logout(w http.ResponseWriter, r *http.Request) {
	// Comma-ok form: a bare assertion would panic when the value is absent.
	sessionID, ok := r.Context().Value(ContextKey("session")).(string)
	if !ok {
		if c, err := r.Cookie("id"); err == nil {
			sessionID, ok = c.Value, true
		}
	}
	if ok {
		s.table.mu.Lock()
		delete(s.table.sessions, sessionID)
		s.table.mu.Unlock()
	}
	http.SetCookie(w, sessionCookie("", -1))
}

// AsMiddleware wraps next so that it only runs for requests with a session.
//
// A request whose "id" cookie names a live session is passed straight
// through. Otherwise (no cookie, unknown ID, or expired session) a new session
// is opened for the user named in the X-Auth-Request-User header; when that
// header is empty the client is redirected to /login/ instead. Whenever next
// runs, the user and session ID are in the request context (see ContextKey).
func (s *SessionStore) AsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if c, err := r.Cookie("id"); err == nil {
			if session, ok := s.lookup(c.Value); ok {
				// NOTE(review): X-Auth-Request-User is ignored here, so if it
				// names a different user than the stored session, the session
				// wins — confirm that is intended.
				nextWithSessionContext(w, r, next, session.User, c.Value)
				return
			}
		}

		// NOTE(review): this header is presumably set by an authenticating
		// reverse proxy; confirm clients cannot reach the app directly and
		// forge it.
		user := r.Header.Get("X-Auth-Request-User")
		if user == "" {
			http.Redirect(w, r, "/login/", http.StatusFound)
			return
		}
		sessionID := s.Login(user, w)
		nextWithSessionContext(w, r, next, user, sessionID)
	})
}

// lookup returns the session stored under id. An expired session is deleted
// and reported as missing.
func (s *SessionStore) lookup(id string) (Session, bool) {
	t := s.table
	t.mu.Lock()
	defer t.mu.Unlock()

	e, ok := t.sessions[id]
	if !ok {
		return Session{}, false
	}
	if time.Now().After(e.expires) {
		delete(t.sessions, id)
		return Session{}, false
	}
	return e.session, true
}

// sessionCookie builds the "id" cookie carrying value with the given Max-Age
// in seconds (a negative maxAge asks the browser to delete the cookie).
func sessionCookie(value string, maxAge int) *http.Cookie {
	return &http.Cookie{
		Name: "id", Value: value, MaxAge: maxAge,
		Secure: true, HttpOnly: true, Path: "/",
	}
}

// nextWithSessionContext calls next with user stored under ContextKey("user")
// and sessionID under ContextKey("session") in the request context.
func nextWithSessionContext(w http.ResponseWriter, r *http.Request, next http.Handler, user string, sessionID string) {
	ctx := context.WithValue(r.Context(), ContextKey("user"), user)
	ctx = context.WithValue(ctx, ContextKey("session"), sessionID)
	next.ServeHTTP(w, r.WithContext(ctx))
}