diff --git a/cmd/server/main.go b/cmd/server/main.go index 8eb6959..7f6d173 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -51,19 +51,7 @@ func main() { etag := middleware.NewETag("static", cacheExpiration) router.Handle("/", middleware.LoggingMiddleware(http.RedirectHandler("/notes/", http.StatusFound))) - router.Handle( - "/notes/", - sessions.AsMiddleware( - middleware.LoggingMiddleware( - middleware.RejectAnonMiddleware( - "/auth/login/", - http.StripPrefix( - "/notes", notesRouter, - ), - ), - ), - ), - ) + router.Handle("/notes/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/notes", notesRouter)))) router.Handle("/auth/", sessions.AsMiddleware(middleware.LoggingMiddleware(http.StripPrefix("/auth", sessionRouter)))) router.Handle( "/static/", diff --git a/internal/conf/conf.go b/internal/conf/conf.go index 9b4cd07..7c3592a 100644 --- a/internal/conf/conf.go +++ b/internal/conf/conf.go @@ -68,7 +68,6 @@ type Config struct { RedirectURL string `toml:"redirect_url"` UserinfoURL string `toml:"userinfo_url"` } - AnonCIDRs []string `toml:"anon_networks"` } var ( diff --git a/internal/middleware/reject_anon.go b/internal/middleware/reject_anon.go deleted file mode 100644 index c5da2a9..0000000 --- a/internal/middleware/reject_anon.go +++ /dev/null @@ -1,96 +0,0 @@ -// Middleware designed to reject requests from anon users unless from 'safe' -// IP addresses - -package middleware - -import ( - "errors" - "fmt" - "log" - "net" - "net/http" - "strings" - - "forgejo.gwairfelin.com/max/gonotes/internal/conf" -) - -type netList []net.IPNet - -const ipHeader = "x-forwarded-for" - -func (n *netList) Contains(ip net.IP) bool { - for _, net := range *n { - if contains := net.Contains(ip); contains { - return true - } - } - return false -} - -func RejectAnonMiddleware(redirect string, next http.Handler) http.Handler { - safeOriginNets := make(netList, 0, len(conf.Conf.AnonCIDRs)) - - for _, cidr := range conf.Conf.AnonCIDRs { - _, net, err := net.ParseCIDR(cidr) - - if err != nil { - log.Printf("ignoring invalid cidr: %s", err) - continue - } - - safeOriginNets = append(safeOriginNets, *net) - } - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value(ContextKey("user")).(string) - - originIP, err := getOriginIP(r) - - if err != nil { - log.Printf("unable to check origin ip: %s", err) - http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) - return - } - - log.Printf("origin ip: %s", originIP) - safeOrigin := safeOriginNets.Contains(originIP) - - if user == "anon" && !safeOrigin { - http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) - return - } - next.ServeHTTP(w, r) - }) -} - -// Get the origin ip from the x-forwarded-for header, or the source of -// the request if not available -func getOriginIP(r *http.Request) (net.IP, error) { - sourceIpHeader, ok := r.Header[http.CanonicalHeaderKey(ipHeader)] - - if !ok { - addrParts := strings.Split(r.RemoteAddr, ":") - - if len(addrParts) == 0 { - return nil, errors.New("no source ip available") - } - - ip := net.ParseIP(addrParts[0]) - if ip == nil { - return nil, fmt.Errorf("ip could not be parsed: %s", addrParts[0]) - } - return ip, nil - } - - if len(sourceIpHeader) != 1 { - return nil, fmt.Errorf("header has more than 1 value: %s=%v", ipHeader, sourceIpHeader) - } - - ips := strings.Split(sourceIpHeader[0], ",") - ip := net.ParseIP(ips[0]) - if ip == nil { - return nil, fmt.Errorf("not parseable as ip: %s", ips[0]) - } - - return ip, nil -}