diff --git a/cmd/server/main.go b/cmd/server/main.go index 7f6d173..0a39c4d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -8,30 +8,23 @@ import ( "os" "time" + "forgejo.gwairfelin.com/max/gonotes/internal/auth" "forgejo.gwairfelin.com/max/gonotes/internal/conf" "forgejo.gwairfelin.com/max/gonotes/internal/middleware" "forgejo.gwairfelin.com/max/gonotes/internal/notes" "forgejo.gwairfelin.com/max/gonotes/internal/notes/views" - "golang.org/x/oauth2" ) func main() { var confFile string + sessions := middleware.NewSessionStore() + flag.StringVar(&confFile, "c", "/etc/gonotes/conf.toml", "Specify path to config file.") flag.Parse() conf.LoadConfig(confFile) - oauth := &oauth2.Config{ - ClientID: conf.Conf.OIDC.ClientID, - ClientSecret: conf.Conf.OIDC.ClientSecret, - Endpoint: oauth2.Endpoint{AuthURL: conf.Conf.OIDC.AuthURL, TokenURL: conf.Conf.OIDC.TokenURL}, - RedirectURL: conf.Conf.OIDC.RedirectURL, - } - - sessions := middleware.NewSessionStore(oauth, "/auth") - err := notes.Init() if err != nil { log.Fatal(err) @@ -41,7 +34,7 @@ func main() { router := http.NewServeMux() notesRouter := views.GetRoutes("/notes") - sessionRouter := sessions.Routes.GetRouter() + authRouter := auth.GetRoutes("/auth", sessions.Login) cacheExpiration, err := time.ParseDuration("24h") if err != nil { @@ -50,9 +43,10 @@ func main() { etag := middleware.NewETag("static", cacheExpiration) + router.Handle("/logout/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.HandlerFunc(sessions.Logout)))) router.Handle("/", middleware.LoggingMiddleware(http.RedirectHandler("/notes/", http.StatusFound))) router.Handle("/notes/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/notes", notesRouter)))) - router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", sessionRouter)))) + router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", authRouter)))) router.Handle( "/static/", middleware.LoggingMiddleware( diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index 6455a12..d3b9992 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -6,43 +6,124 @@ import ( "encoding/base64" "encoding/json" "fmt" + "log" "net/http" + urls "forgejo.gwairfelin.com/max/gispatcho" "forgejo.gwairfelin.com/max/gonotes/internal/conf" "golang.org/x/oauth2" ) -func GenerateStateOAUTHCookie(w http.ResponseWriter, prefix string) string { +var myurls urls.URLs +var oauthConfig *oauth2.Config +var loginFunction func(user string, w http.ResponseWriter) string + +type userInfo struct { + preferred_username string +} + +func GetRoutes(prefix string, _loginFunction func(user string, w http.ResponseWriter) string) *http.ServeMux { + loginFunction = _loginFunction + + oauthConfig = &oauth2.Config{ + ClientID: conf.Conf.OIDC.ClientID, + ClientSecret: conf.Conf.OIDC.ClientSecret, + Endpoint: oauth2.Endpoint{AuthURL: conf.Conf.OIDC.AuthURL, TokenURL: conf.Conf.OIDC.TokenURL}, + RedirectURL: conf.Conf.OIDC.RedirectURL, + } + + myurls = urls.URLs{ + Prefix: prefix, + URLs: map[string]urls.URL{ + "login": {Path: "/oauth/login/", Protocol: "GET", Handler: oauthLogin}, + "callback": {Path: "/oauth/callback/", Protocol: "GET", Handler: oauthCallback}, + }, + } + return myurls.GetRouter() +} + +func oauthLogin(w http.ResponseWriter, r *http.Request) { + log.Printf("%+v", *oauthConfig) + + oauthState := generateStateOAUTHCookie(w) + + url := oauthConfig.AuthCodeURL(oauthState) + log.Printf("Redirecting to %s", url) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) +} + +func generateStateOAUTHCookie(w http.ResponseWriter) string { b := make([]byte, 16) rand.Read(b) state := base64.URLEncoding.EncodeToString(b) cookie := http.Cookie{ Name: "oauthstate", Value: state, - MaxAge: 30, Secure: true, HttpOnly: true, Path: prefix, + MaxAge: 30, Secure: true, HttpOnly: true, Path: "/auth/oauth/", } http.SetCookie(w, &cookie) return state } -func GetUserFromForgejo(code string, oauth *oauth2.Config) (string, error) { +func oauthCallback(w http.ResponseWriter, r *http.Request) { + // Read oauthState from Cookie + oauthState, err := r.Cookie("oauthstate") + if err != nil { + log.Printf("An error occured during login: %s", err) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + log.Printf("%v", oauthState) + + if r.FormValue("state") != oauthState.Value { + log.Println("invalid oauth state") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + data, err := getUserData(r.FormValue("code")) + if err != nil { + log.Println(err.Error()) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + username, ok := data["preferred_username"] + if !ok { + log.Println("No username in auth response") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + userStr, ok := username.(string) + if !ok { + log.Println("Username not interpretable as string") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + loginFunction(userStr, w) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) +} + +func getUserData(code string) (map[string]any, error) { // Use code to get token and get user info from Google. - token, err := oauth.Exchange(context.Background(), code) + token, err := oauthConfig.Exchange(context.Background(), code) if err != nil { - return "", fmt.Errorf("code exchange wrong: %s", err.Error()) + return nil, fmt.Errorf("code exchange wrong: %s", err.Error()) } request, err := http.NewRequest("GET", conf.Conf.OIDC.UserinfoURL, nil) if err != nil { - return "", fmt.Errorf("failed to init http client for userinfo: %s", err.Error()) + return nil, fmt.Errorf("failed to init http client for userinfo: %s", err.Error()) } request.Header.Set("Authorization", fmt.Sprintf("token %s", token.AccessToken)) response, err := http.DefaultClient.Do(request) if err != nil { - return "", fmt.Errorf("failed getting user info: %s", err.Error()) + return nil, fmt.Errorf("failed getting user info: %s", err.Error()) } defer response.Body.Close() @@ -50,17 +131,10 @@ func GetUserFromForgejo(code string, oauth *oauth2.Config) (string, error) { err = json.NewDecoder(response.Body).Decode(&uInf) if err != nil { - return "", fmt.Errorf("failed to parse response as json: %s", err.Error()) + return nil, fmt.Errorf("failed to parse response as json: %s", err.Error()) } - username, ok := uInf["preferred_username"] - if !ok { - return "", fmt.Errorf("no username in response: %s", err.Error()) - } - userStr, ok := username.(string) - if !ok { - return "", fmt.Errorf("username not a string: %s", err.Error()) - } + log.Printf("Contents of user data response %s", uInf) - return userStr, nil + return uInf, nil } diff --git a/internal/conf/templates/base.tmpl.html b/internal/conf/templates/base.tmpl.html index 6d776ff..8ce885a 100644 --- a/internal/conf/templates/base.tmpl.html +++ b/internal/conf/templates/base.tmpl.html @@ -64,11 +64,11 @@ {{template "navLinks" .}} {{if eq .user "anon"}}
  • - Login + Login
  • {{else}}
  • - Logout {{.user}} + Logout {{.user}}
  • {{end}} diff --git a/internal/middleware/session.go b/internal/middleware/session.go index a3af9fc..8a11c83 100644 --- a/internal/middleware/session.go +++ b/internal/middleware/session.go @@ -4,12 +4,7 @@ package middleware import ( "context" "crypto/rand" - "log" "net/http" - - urls "forgejo.gwairfelin.com/max/gispatcho" - "forgejo.gwairfelin.com/max/gonotes/internal/auth" - "golang.org/x/oauth2" ) type Session struct { @@ -18,29 +13,14 @@ type Session struct { type SessionStore struct { sessions map[string]Session - oauth *oauth2.Config - Routes urls.URLs } type ContextKey string -func NewSessionStore(oauth *oauth2.Config, prefix string) SessionStore { - store := SessionStore{ - sessions: make(map[string]Session, 10), - oauth: oauth, - } - store.Routes = urls.URLs{ - Prefix: prefix, - URLs: map[string]urls.URL{ - "login": {Path: "/login/", Protocol: "GET", Handler: store.LoginViewOAUTH}, - "callback": {Path: "/callback/", Protocol: "GET", Handler: store.CallbackViewOAUTH}, - "logout": {Path: "/logout/", Protocol: "GET", Handler: store.LogoutView}, - }, - } - return store +func NewSessionStore() SessionStore { + return SessionStore{sessions: make(map[string]Session, 10)} } -// Log a user in func (s *SessionStore) Login(user string, w http.ResponseWriter) string { sessionID := rand.Text() s.sessions[sessionID] = Session{User: user} @@ -54,8 +34,7 @@ func (s *SessionStore) Login(user string, w http.ResponseWriter) string { return sessionID } -// View to logout a user -func (s *SessionStore) LogoutView(w http.ResponseWriter, r *http.Request) { +func (s *SessionStore) Logout(w http.ResponseWriter, r *http.Request) { session := r.Context().Value(ContextKey("session")).(string) delete(s.sessions, session) @@ -68,48 +47,6 @@ func (s *SessionStore) LogoutView(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } -// View to log in a user via oauth -func (s *SessionStore) LoginViewOAUTH(w http.ResponseWriter, r *http.Request) { - log.Printf("%+v", *s.oauth) - - oauthState := auth.GenerateStateOAUTHCookie(w, s.Routes.Prefix) - - url := s.oauth.AuthCodeURL(oauthState) - log.Printf("Redirecting to %s", url) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) -} - -// Oauth callback view -func (s *SessionStore) CallbackViewOAUTH(w http.ResponseWriter, r *http.Request) { - // Read oauthState from Cookie - oauthState, err := r.Cookie("oauthstate") - if err != nil { - log.Printf("An error occured during login: %s", err) - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) - return - } - - log.Printf("%v", oauthState) - - if r.FormValue("state") != oauthState.Value { - log.Println("invalid oauth state") - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) - return - } - - username, err := auth.GetUserFromForgejo(r.FormValue("code"), s.oauth) - if err != nil { - log.Println(err.Error()) - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) - return - } - - s.Login(username, w) - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) -} - -// Turn the session store into a middleware. -// Sets the user on the context based on the available session cookie func (s *SessionStore) AsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sessionCookie, err := r.Cookie("id")