// Middleware designed to reject requests from anon users unless from 'safe' // IP addresses package middleware import ( "errors" "fmt" "log" "net" "net/http" "strings" "forgejo.gwairfelin.com/max/gonotes/internal/conf" ) type netList []net.IPNet const ipHeader = "x-forwarded-for" // Check if any IPNet in the netList contains the given IP func (n *netList) Contains(ip net.IP) bool { for _, net := range *n { if contains := net.Contains(ip); contains { return true } } return false } // Redirect to redirect url any request where the user is anon and the request // does not appear to come from a safe origin func RejectAnonMiddleware(redirect string, next http.Handler) http.Handler { safeOriginNets := make(netList, 0, len(conf.Conf.AnonCIDRs)) for _, cidr := range conf.Conf.AnonCIDRs { _, net, err := net.ParseCIDR(cidr) if err != nil { log.Printf("ignoring invalid cidr: %s", err) continue } safeOriginNets = append(safeOriginNets, *net) } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := r.Context().Value(ContextKey("user")).(string) originIP, err := getOriginIP(r) if err != nil { log.Printf("unable to check origin ip: %s", err) http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) return } log.Printf("origin ip: %s", originIP) safeOrigin := safeOriginNets.Contains(originIP) if user == "anon" && !safeOrigin { http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) return } next.ServeHTTP(w, r) }) } // Get the origin ip from the x-forwarded-for header, or the source of // the request if not available func getOriginIP(r *http.Request) (net.IP, error) { sourceIpHeader, ok := r.Header[http.CanonicalHeaderKey(ipHeader)] if !ok { addrParts := strings.Split(r.RemoteAddr, ":") if len(addrParts) == 0 { return nil, errors.New("no source ip available") } ip := net.ParseIP(addrParts[0]) if ip == nil { return nil, fmt.Errorf("ip could not be parsed: %s", addrParts[0]) } return ip, nil } if len(sourceIpHeader) != 1 { return nil, fmt.Errorf("header has more than 1 value: %s=%v", ipHeader, sourceIpHeader) } ips := strings.Split(sourceIpHeader[0], ",") ip := net.ParseIP(ips[0]) if ip == nil { return nil, fmt.Errorf("not parseable as ip: %s", ips[0]) } return ip, nil }