gonotes/internal/middleware/reject_anon.go

96 lines
2.1 KiB
Go

// Middleware designed to reject requests from anon users unless from 'safe'
// IP addresses
package middleware
import (
"errors"
"fmt"
"log"
"net"
"net/http"
"strings"
"forgejo.gwairfelin.com/max/gonotes/internal/conf"
)
// netList is a set of IP networks that can be queried for membership.
type netList []net.IPNet

// ipHeader names the request header consulted for the originating address.
const ipHeader = "x-forwarded-for"

// Contains reports whether ip lies inside at least one network in the list.
// An empty list contains nothing.
func (n *netList) Contains(ip net.IP) bool {
	networks := *n
	for i := range networks {
		if networks[i].Contains(ip) {
			return true
		}
	}
	return false
}
// RejectAnonMiddleware wraps next so that requests from the "anon" user are
// redirected to redirect (with a 307) unless their origin IP falls inside one
// of the networks listed in conf.Conf.AnonCIDRs.
//
// The CIDR list is parsed once, when the middleware is built; entries that do
// not parse are logged and skipped. The user name is read from the request
// context under ContextKey("user"). A request carrying no string value under
// that key is treated as "anon" instead of panicking on the type assertion, so
// the check fails closed. Requests from any other user are always passed on
// to next, even when the origin IP cannot be determined.
func RejectAnonMiddleware(redirect string, next http.Handler) http.Handler {
	safeOriginNets := make(netList, 0, len(conf.Conf.AnonCIDRs))
	for _, cidr := range conf.Conf.AnonCIDRs {
		// Named ipNet so the local does not shadow the net package.
		_, ipNet, err := net.ParseCIDR(cidr)
		if err != nil {
			log.Printf("ignoring invalid cidr: %s", err)
			continue
		}
		safeOriginNets = append(safeOriginNets, *ipNet)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		user, ok := r.Context().Value(ContextKey("user")).(string)
		if !ok {
			// No user recorded on the request: handle it like an anonymous one.
			user = "anon"
		}

		originIP, err := getOriginIP(r)
		if err != nil {
			log.Printf("unable to check origin ip: %s", err)
		} else {
			log.Printf("origin ip: %s", originIP)
		}

		// Only anonymous requests are subject to the origin check; an unknown
		// origin counts as unsafe for them.
		if user == "anon" && (err != nil || !safeOriginNets.Contains(originIP)) {
			http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// getOriginIP returns the IP address a request originated from.
//
// When the x-forwarded-for header is present it must carry exactly one value;
// the first comma-separated entry of that value, with surrounding whitespace
// removed, is parsed as the origin. When the header is absent the host part of
// r.RemoteAddr is used instead, which handles both "1.2.3.4:port" and
// "[::1]:port" forms as well as a bare address with no port.
//
// An error is returned when no address is available, when the header has a
// number of values other than one, or when the candidate is not a valid IP.
//
// NOTE(review): the first x-forwarded-for entry is whatever the client or an
// upstream hop sent; confirm the fronting proxy overwrites or sanitises the
// header before relying on it for access decisions.
func getOriginIP(r *http.Request) (net.IP, error) {
	forwarded, ok := r.Header[http.CanonicalHeaderKey(ipHeader)]
	if !ok {
		if r.RemoteAddr == "" {
			return nil, errors.New("no source ip available")
		}
		// SplitHostPort understands bracketed IPv6 literals, which a plain
		// split on ":" would tear apart.
		host, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			// No port present; treat the whole value as the host.
			host = r.RemoteAddr
		}
		ip := net.ParseIP(host)
		if ip == nil {
			return nil, fmt.Errorf("ip could not be parsed: %s", host)
		}
		return ip, nil
	}

	if len(forwarded) != 1 {
		return nil, fmt.Errorf("header has more than 1 value: %s=%v", ipHeader, forwarded)
	}

	// Entries are comma separated and may be padded with spaces.
	entries := strings.Split(forwarded[0], ",")
	first := strings.TrimSpace(entries[0])
	ip := net.ParseIP(first)
	if ip == nil {
		return nil, fmt.Errorf("not parseable as ip: %s", first)
	}
	return ip, nil
}