diff --git a/cmd/server/main.go b/cmd/server/main.go index 7f6d173..0a39c4d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -8,30 +8,23 @@ import ( "os" "time" + "forgejo.gwairfelin.com/max/gonotes/internal/auth" "forgejo.gwairfelin.com/max/gonotes/internal/conf" "forgejo.gwairfelin.com/max/gonotes/internal/middleware" "forgejo.gwairfelin.com/max/gonotes/internal/notes" "forgejo.gwairfelin.com/max/gonotes/internal/notes/views" - "golang.org/x/oauth2" ) func main() { var confFile string + sessions := middleware.NewSessionStore() + flag.StringVar(&confFile, "c", "/etc/gonotes/conf.toml", "Specify path to config file.") flag.Parse() conf.LoadConfig(confFile) - oauth := &oauth2.Config{ - ClientID: conf.Conf.OIDC.ClientID, - ClientSecret: conf.Conf.OIDC.ClientSecret, - Endpoint: oauth2.Endpoint{AuthURL: conf.Conf.OIDC.AuthURL, TokenURL: conf.Conf.OIDC.TokenURL}, - RedirectURL: conf.Conf.OIDC.RedirectURL, - } - - sessions := middleware.NewSessionStore(oauth, "/auth") - err := notes.Init() if err != nil { log.Fatal(err) @@ -41,7 +34,7 @@ func main() { router := http.NewServeMux() notesRouter := views.GetRoutes("/notes") - sessionRouter := sessions.Routes.GetRouter() + authRouter := auth.GetRoutes("/auth", sessions.Login) cacheExpiration, err := time.ParseDuration("24h") if err != nil { @@ -50,9 +43,10 @@ func main() { etag := middleware.NewETag("static", cacheExpiration) + router.Handle("/logout/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.HandlerFunc(sessions.Logout)))) router.Handle("/", middleware.LoggingMiddleware(http.RedirectHandler("/notes/", http.StatusFound))) router.Handle("/notes/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/notes", notesRouter)))) - router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", sessionRouter)))) + router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", authRouter)))) router.Handle( "/static/", middleware.LoggingMiddleware( diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index 6455a12..d3b9992 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -6,43 +6,124 @@ import ( "encoding/base64" "encoding/json" "fmt" + "log" "net/http" + urls "forgejo.gwairfelin.com/max/gispatcho" "forgejo.gwairfelin.com/max/gonotes/internal/conf" "golang.org/x/oauth2" ) -func GenerateStateOAUTHCookie(w http.ResponseWriter, prefix string) string { +var myurls urls.URLs +var oauthConfig *oauth2.Config +var loginFunction func(user string, w http.ResponseWriter) string + +type userInfo struct { + preferred_username string +} + +func GetRoutes(prefix string, _loginFunction func(user string, w http.ResponseWriter) string) *http.ServeMux { + loginFunction = _loginFunction + + oauthConfig = &oauth2.Config{ + ClientID: conf.Conf.OIDC.ClientID, + ClientSecret: conf.Conf.OIDC.ClientSecret, + Endpoint: oauth2.Endpoint{AuthURL: conf.Conf.OIDC.AuthURL, TokenURL: conf.Conf.OIDC.TokenURL}, + RedirectURL: conf.Conf.OIDC.RedirectURL, + } + + myurls = urls.URLs{ + Prefix: prefix, + URLs: map[string]urls.URL{ + "login": {Path: "/oauth/login/", Protocol: "GET", Handler: oauthLogin}, + "callback": {Path: "/oauth/callback/", Protocol: "GET", Handler: oauthCallback}, + }, + } + return myurls.GetRouter() +} + +func oauthLogin(w http.ResponseWriter, r *http.Request) { + log.Printf("%+v", *oauthConfig) + + oauthState := generateStateOAUTHCookie(w) + + url := oauthConfig.AuthCodeURL(oauthState) + log.Printf("Redirecting to %s", url) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) +} + +func generateStateOAUTHCookie(w http.ResponseWriter) string { b := make([]byte, 16) rand.Read(b) state := base64.URLEncoding.EncodeToString(b) cookie := http.Cookie{ Name: "oauthstate", Value: state, - MaxAge: 30, Secure: true, HttpOnly: true, Path: prefix, + MaxAge: 30, Secure: true, HttpOnly: true, Path: "/auth/oauth/", } http.SetCookie(w, &cookie) return state } -func GetUserFromForgejo(code string, oauth *oauth2.Config) (string, error) { +func oauthCallback(w http.ResponseWriter, r *http.Request) { + // Read oauthState from Cookie + oauthState, err := r.Cookie("oauthstate") + if err != nil { + log.Printf("An error occured during login: %s", err) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + log.Printf("%v", oauthState) + + if r.FormValue("state") != oauthState.Value { + log.Println("invalid oauth state") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + data, err := getUserData(r.FormValue("code")) + if err != nil { + log.Println(err.Error()) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + username, ok := data["preferred_username"] + if !ok { + log.Println("No username in auth response") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + userStr, ok := username.(string) + if !ok { + log.Println("Username not interpretable as string") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + loginFunction(userStr, w) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) +} + +func getUserData(code string) (map[string]any, error) { // Use code to get token and get user info from Google. - token, err := oauth.Exchange(context.Background(), code) + token, err := oauthConfig.Exchange(context.Background(), code) if err != nil { - return "", fmt.Errorf("code exchange wrong: %s", err.Error()) + return nil, fmt.Errorf("code exchange wrong: %s", err.Error()) } request, err := http.NewRequest("GET", conf.Conf.OIDC.UserinfoURL, nil) if err != nil { - return "", fmt.Errorf("failed to init http client for userinfo: %s", err.Error()) + return nil, fmt.Errorf("failed to init http client for userinfo: %s", err.Error()) } request.Header.Set("Authorization", fmt.Sprintf("token %s", token.AccessToken)) response, err := http.DefaultClient.Do(request) if err != nil { - return "", fmt.Errorf("failed getting user info: %s", err.Error()) + return nil, fmt.Errorf("failed getting user info: %s", err.Error()) } defer response.Body.Close() @@ -50,17 +131,10 @@ func GetUserFromForgejo(code string, oauth *oauth2.Config) (string, error) { err = json.NewDecoder(response.Body).Decode(&uInf) if err != nil { - return "", fmt.Errorf("failed to parse response as json: %s", err.Error()) + return nil, fmt.Errorf("failed to parse response as json: %s", err.Error()) } - username, ok := uInf["preferred_username"] - if !ok { - return "", fmt.Errorf("no username in response: %s", err.Error()) - } - userStr, ok := username.(string) - if !ok { - return "", fmt.Errorf("username not a string: %s", err.Error()) - } + log.Printf("Contents of user data response %s", uInf) - return userStr, nil + return uInf, nil } diff --git a/internal/conf/templates/base.tmpl.html b/internal/conf/templates/base.tmpl.html index 6d776ff..8ce885a 100644 --- a/internal/conf/templates/base.tmpl.html +++ b/internal/conf/templates/base.tmpl.html @@ -64,11 +64,11 @@ {{template "navLinks" .}} {{if eq .user "anon"}}