diff --git a/cmd/server/main.go b/cmd/server/main.go index 7f6d173..8eb6959 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -51,7 +51,19 @@ func main() { etag := middleware.NewETag("static", cacheExpiration) router.Handle("/", middleware.LoggingMiddleware(http.RedirectHandler("/notes/", http.StatusFound))) - router.Handle("/notes/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/notes", notesRouter)))) + router.Handle( + "/notes/", + sessions.AsMiddleware( + middleware.LoggingMiddleware( + middleware.RejectAnonMiddleware( + "/auth/login/", + http.StripPrefix( + "/notes", notesRouter, + ), + ), + ), + ), + ) router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", sessionRouter)))) router.Handle( "/static/", diff --git a/internal/conf/conf.go b/internal/conf/conf.go index 7c3592a..9b4cd07 100644 --- a/internal/conf/conf.go +++ b/internal/conf/conf.go @@ -68,6 +68,7 @@ type Config struct { RedirectURL string `toml:"redirect_url"` UserinfoURL string `toml:"userinfo_url"` } + AnonCIDRs []string `toml:"anon_networks"` } var ( diff --git a/internal/middleware/reject_anon.go b/internal/middleware/reject_anon.go new file mode 100644 index 0000000..c5da2a9 --- /dev/null +++ b/internal/middleware/reject_anon.go @@ -0,0 +1,96 @@ +// Middleware designed to reject requests from anon users unless from 'safe' +// IP addresses + +package middleware + +import ( + "errors" + "fmt" + "log" + "net" + "net/http" + "strings" + + "forgejo.gwairfelin.com/max/gonotes/internal/conf" +) + +type netList []net.IPNet + +const ipHeader = "x-forwarded-for" + +func (n *netList) Contains(ip net.IP) bool { + for _, net := range *n { + if contains := net.Contains(ip); contains { + return true + } + } + return false +} + +func RejectAnonMiddleware(redirect string, next http.Handler) http.Handler { + safeOriginNets := make(netList, 0, len(conf.Conf.AnonCIDRs)) + + for _, cidr := range conf.Conf.AnonCIDRs { + _, net, err := net.ParseCIDR(cidr) + + if err != nil { + log.Printf("ignoring invalid cidr: %s", err) + continue + } + + safeOriginNets = append(safeOriginNets, *net) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value(ContextKey("user")).(string) + + originIP, err := getOriginIP(r) + + if err != nil { + log.Printf("unable to check origin ip: %s", err) + http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) + return + } + + log.Printf("origin ip: %s", originIP) + safeOrigin := safeOriginNets.Contains(originIP) + + if user == "anon" && !safeOrigin { + http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) + return + } + next.ServeHTTP(w, r) + }) +} + +// Get the origin ip from the x-forwarded-for header, or the source of +// the request if not available +func getOriginIP(r *http.Request) (net.IP, error) { + sourceIpHeader, ok := r.Header[http.CanonicalHeaderKey(ipHeader)] + + if !ok { + addrParts := strings.Split(r.RemoteAddr, ":") + + if len(addrParts) == 0 { + return nil, errors.New("no source ip available") + } + + ip := net.ParseIP(addrParts[0]) + if ip == nil { + return nil, fmt.Errorf("ip could not be parsed: %s", addrParts[0]) + } + return ip, nil + } + + if len(sourceIpHeader) != 1 { + return nil, fmt.Errorf("header has more than 1 value: %s=%v", ipHeader, sourceIpHeader) + } + + ips := strings.Split(sourceIpHeader[0], ",") + ip := net.ParseIP(ips[0]) + if ip == nil { + return nil, fmt.Errorf("not parseable as ip: %s", ips[0]) + } + + return ip, nil +}