Add reject anon middleware
This commit is contained in:
parent
a1c5827641
commit
63405b6dc2
1 changed files with 99 additions and 0 deletions
99
internal/middleware/reject_anon.go
Normal file
99
internal/middleware/reject_anon.go
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
// Middleware designed to reject requests from anon users unless from 'safe'
|
||||||
|
// IP addresses
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type netList []net.IPNet
|
||||||
|
|
||||||
|
var safeCIDRs = [...]string{"192.168.0.0/23", "10.0.0.0/24", "2001:8b0:f70:546d::/64"}
|
||||||
|
|
||||||
|
var safeOriginNets netList
|
||||||
|
|
||||||
|
const ipHeader = "x-forwarded-for"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
safeOriginNets = make([]net.IPNet, 0, len(safeCIDRs))
|
||||||
|
for _, cidr := range safeCIDRs {
|
||||||
|
_, net, err := net.ParseCIDR(cidr)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("ignoring invalid cidr: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
safeOriginNets = append(safeOriginNets, *net)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *netList) Contains(ip net.IP) bool {
|
||||||
|
for _, net := range *n {
|
||||||
|
if contains := net.Contains(ip); contains {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func RejectAnonMiddleware(redirect string, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user := r.Context().Value(ContextKey("user")).(string)
|
||||||
|
|
||||||
|
originIP, err := getOriginIP(r)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("unable to check origin ip: %s", err)
|
||||||
|
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("origin ip: %s", originIP)
|
||||||
|
safeOrigin := safeOriginNets.Contains(originIP)
|
||||||
|
|
||||||
|
if user == "anon" && !safeOrigin {
|
||||||
|
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the origin ip from the x-forwarded-for header, or the source of
|
||||||
|
// the request if not available
|
||||||
|
func getOriginIP(r *http.Request) (net.IP, error) {
|
||||||
|
sourceIpHeader, ok := r.Header[http.CanonicalHeaderKey(ipHeader)]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
addrParts := strings.Split(r.RemoteAddr, ":")
|
||||||
|
|
||||||
|
if len(addrParts) == 0 {
|
||||||
|
return nil, errors.New("no source ip available")
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := net.ParseIP(addrParts[0])
|
||||||
|
if ip == nil {
|
||||||
|
return nil, fmt.Errorf("ip could not be parsed: %s", addrParts[0])
|
||||||
|
}
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sourceIpHeader) != 1 {
|
||||||
|
return nil, fmt.Errorf("header has more than 1 value: %s=%v", ipHeader, sourceIpHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
ips := strings.Split(sourceIpHeader[0], ",")
|
||||||
|
ip := net.ParseIP(ips[0])
|
||||||
|
if ip == nil {
|
||||||
|
return nil, fmt.Errorf("not parseable as ip: %s", ips[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue