diff --git a/cmd/server/main.go b/cmd/server/main.go index e406271..0a39c4d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -8,6 +8,7 @@ import ( "os" "time" + "forgejo.gwairfelin.com/max/gonotes/internal/auth" "forgejo.gwairfelin.com/max/gonotes/internal/conf" "forgejo.gwairfelin.com/max/gonotes/internal/middleware" "forgejo.gwairfelin.com/max/gonotes/internal/notes" @@ -33,6 +34,7 @@ func main() { router := http.NewServeMux() notesRouter := views.GetRoutes("/notes") + authRouter := auth.GetRoutes("/auth", sessions.Login) cacheExpiration, err := time.ParseDuration("24h") if err != nil { @@ -41,31 +43,10 @@ func main() { etag := middleware.NewETag("static", cacheExpiration) - if !conf.Conf.Production { - router.HandleFunc("/login/", func(w http.ResponseWriter, r *http.Request) { - user := r.FormValue("user") - if len(user) == 0 { - user = "anon" - } - sessions.Login(user, w) - - http.Redirect(w, r, "/notes/", http.StatusFound) - }) - } - router.Handle("/logout/", sessions.AsMiddleware( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := r.FormValue("user") - if len(user) == 0 { - user = "anon" - } - - sessions.Logout(w, r) - - http.Redirect(w, r, "/notes/", http.StatusFound) - }))) - + router.Handle("/logout/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.HandlerFunc(sessions.Logout)))) router.Handle("/", middleware.LoggingMiddleware(http.RedirectHandler("/notes/", http.StatusFound))) router.Handle("/notes/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/notes", notesRouter)))) + router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", authRouter)))) router.Handle( "/static/", middleware.LoggingMiddleware( diff --git a/go.mod b/go.mod index 94853c5..6ca59a3 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,7 @@ require ( forgejo.gwairfelin.com/max/gispatcho v0.1.2 github.com/pelletier/go-toml/v2 v2.2.3 github.com/teekennedy/goldmark-markdown v0.5.1 -) - -require ( - github.com/yuin/goldmark-meta v1.1.0 // indirect - gopkg.in/yaml.v2 v2.3.0 // indirect + github.com/yuin/goldmark-meta v1.1.0 + golang.org/x/oauth2 v0.34.0 + gopkg.in/yaml.v2 v2.3.0 ) diff --git a/go.sum b/go.sum index b7c1a06..d5a5d0c 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,9 @@ github.com/yuin/goldmark-meta v1.1.0 h1:pWw+JLHGZe8Rk0EGsMVssiNb/AaPMHfSRszZeUei github.com/yuin/goldmark-meta v1.1.0/go.mod h1:U4spWENafuA7Zyg+Lj5RqK/MF+ovMYtBvXi1lBb2VP0= go.abhg.dev/goldmark/toc v0.11.0 h1:IRixVy3/yVPKvFBc37EeBPi8XLTXrtH6BYaonSjkF8o= go.abhg.dev/goldmark/toc v0.11.0/go.mod h1:XMFIoI1Sm6dwF9vKzVDOYE/g1o5BmKXghLG8q/wJNww= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go new file mode 100644 index 0000000..d3b9992 --- /dev/null +++ b/internal/auth/oauth.go @@ -0,0 +1,140 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "log" + "net/http" + + urls "forgejo.gwairfelin.com/max/gispatcho" + "forgejo.gwairfelin.com/max/gonotes/internal/conf" + "golang.org/x/oauth2" +) + +var myurls urls.URLs +var oauthConfig *oauth2.Config +var loginFunction func(user string, w http.ResponseWriter) string + +type userInfo struct { + preferred_username string +} + +func GetRoutes(prefix string, _loginFunction func(user string, w http.ResponseWriter) string) *http.ServeMux { + loginFunction = _loginFunction + + oauthConfig = &oauth2.Config{ + ClientID: conf.Conf.OIDC.ClientID, + ClientSecret: conf.Conf.OIDC.ClientSecret, + Endpoint: oauth2.Endpoint{AuthURL: conf.Conf.OIDC.AuthURL, TokenURL: conf.Conf.OIDC.TokenURL}, + RedirectURL: conf.Conf.OIDC.RedirectURL, + } + + myurls = urls.URLs{ + Prefix: prefix, + URLs: map[string]urls.URL{ + "login": {Path: "/oauth/login/", Protocol: "GET", Handler: oauthLogin}, + "callback": {Path: "/oauth/callback/", Protocol: "GET", Handler: oauthCallback}, + }, + } + return myurls.GetRouter() +} + +func oauthLogin(w http.ResponseWriter, r *http.Request) { + log.Printf("%+v", *oauthConfig) + + oauthState := generateStateOAUTHCookie(w) + + url := oauthConfig.AuthCodeURL(oauthState) + log.Printf("Redirecting to %s", url) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) +} + +func generateStateOAUTHCookie(w http.ResponseWriter) string { + + b := make([]byte, 16) + rand.Read(b) + state := base64.URLEncoding.EncodeToString(b) + cookie := http.Cookie{ + Name: "oauthstate", Value: state, + MaxAge: 30, Secure: true, HttpOnly: true, Path: "/auth/oauth/", + } + http.SetCookie(w, &cookie) + + return state +} + +func oauthCallback(w http.ResponseWriter, r *http.Request) { + // Read oauthState from Cookie + oauthState, err := r.Cookie("oauthstate") + if err != nil { + log.Printf("An error occured during login: %s", err) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + log.Printf("%v", oauthState) + + if r.FormValue("state") != oauthState.Value { + log.Println("invalid oauth state") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + data, err := getUserData(r.FormValue("code")) + if err != nil { + log.Println(err.Error()) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + username, ok := data["preferred_username"] + if !ok { + log.Println("No username in auth response") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + userStr, ok := username.(string) + if !ok { + log.Println("Username not interpretable as string") + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return + } + + loginFunction(userStr, w) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) +} + +func getUserData(code string) (map[string]any, error) { + // Use code to get token and get user info from Google. + + token, err := oauthConfig.Exchange(context.Background(), code) + if err != nil { + return nil, fmt.Errorf("code exchange wrong: %s", err.Error()) + } + + request, err := http.NewRequest("GET", conf.Conf.OIDC.UserinfoURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to init http client for userinfo: %s", err.Error()) + } + request.Header.Set("Authorization", fmt.Sprintf("token %s", token.AccessToken)) + response, err := http.DefaultClient.Do(request) + + if err != nil { + return nil, fmt.Errorf("failed getting user info: %s", err.Error()) + } + defer response.Body.Close() + + uInf := make(map[string]any) + + err = json.NewDecoder(response.Body).Decode(&uInf) + if err != nil { + return nil, fmt.Errorf("failed to parse response as json: %s", err.Error()) + } + + log.Printf("Contents of user data response %s", uInf) + + return uInf, nil +} diff --git a/internal/conf/conf.go b/internal/conf/conf.go index a47dddd..7c3592a 100644 --- a/internal/conf/conf.go +++ b/internal/conf/conf.go @@ -60,6 +60,14 @@ type Config struct { NotesDir string LogAccess bool Production bool + OIDC struct { + ClientID string `toml:"client_id"` + ClientSecret string `toml:"client_secret"` + AuthURL string `toml:"auth_url"` + TokenURL string `toml:"token_url"` + RedirectURL string `toml:"redirect_url"` + UserinfoURL string `toml:"userinfo_url"` + } } var ( diff --git a/internal/conf/templates/base.tmpl.html b/internal/conf/templates/base.tmpl.html index bd936f0..8ce885a 100644 --- a/internal/conf/templates/base.tmpl.html +++ b/internal/conf/templates/base.tmpl.html @@ -62,9 +62,15 @@ {{template "navLinks" .}} + {{if eq .user "anon"}} +