Files
server/opt/psa/admin/sbin/modules/firewall/ipsets
2026-01-07 20:52:11 +01:00

358 lines
14 KiB
Python
Executable File

#!/usr/local/psa/bin/py3-python -IS
""" ipset management for country filtering in firewall. """
import argparse
import ipaddress
import json
import logging
import os
import subprocess
import sys
import textwrap
log = logging.getLogger('ipsets')

# Directory this script is installed in; bundled data source scripts live below it.
SBIN_D = os.path.dirname(os.path.abspath(__file__))
# Writable state directory of the firewall module.
VAR_D = "/usr/local/psa/var/modules/firewall"

# Data source scripts are looked up in the bundled directory first, then in VAR_D.
DATA_SOURCE_BIN_D = os.path.join(SBIN_D, "geoip")
DATA_SOURCE_VAR_D = os.path.join(VAR_D, "geoip")
# Persisted --configure settings (country list) reused by --update and --recreate.
SETTINGS_PATH = os.path.join(DATA_SOURCE_VAR_D, "settings.json")

# Country ipsets are named IPSET_PREFIX + <IP version> + "-" + <country code>.
IPSET_PREFIX = "plesk-ip"
def set_up_logging(verbosity):
    """ Configures root logging from the --verbose count and the PLESK_DEBUG environment.

    Verbosity 0 (or None) -> CRITICAL, 1 -> ERROR, 2 -> WARNING, 3 -> INFO, 4+ -> DEBUG.
    A non-empty PLESK_DEBUG environment variable forces DEBUG regardless of verbosity.
    """
    count = verbosity or 0
    if count >= 4 or os.getenv('PLESK_DEBUG'):
        level = logging.DEBUG
    else:
        # Anything not listed (0, or an unexpected value) stays at CRITICAL.
        level = {1: logging.ERROR, 2: logging.WARNING, 3: logging.INFO}.get(count, logging.CRITICAL)
    logging.basicConfig(level=level, format='[%(asctime)s] %(levelname)8s %(message)s')
def parse_args():
    """ Parses the command line.

    Exactly one of --configure, --update or --recreate is required, and --data-source
    is required as well; argparse exits the process on invalid arguments.

    Returns an argparse.Namespace with:
      configure, update, recreate -- command flags (bool, exactly one is True)
      verbose     -- number of -v options (int)
      force       -- bool
      data_source -- path of the data source script, resolved by type_data_source()
      countries   -- list of codes checked by type_country_code(); [] for a bare -c,
                     None when -c is not given at all
    """
    # Indented here only for readability: textwrap.dedent() below strips the common margin.
    epilog = f"""\
        environment variables:
          DOWNLOAD_TIMEOUT  Data source download timeout, seconds
          LICENSE_KEY       Data source license key (e.g. for 'maxmind')
          PLESK_DEBUG       Set logging verbosity to maximum

        data source contract:
          Each --data-source value is an executable script with the following commands:
          --exists    Returns 0 only when the GeoIP data exists locally
                      (i.e. previous --fetch was successful).
          --fetch     Fetches GeoIP data from a remote source, preprocesses it,
                      and stores it locally. May use and store additional
                      environment variables, such as LICENSE_KEY. Such variables
                      may be absent on subsequent calls. Store data under
                      {DATA_SOURCE_VAR_D}/$data_source.d .
                      Avoid clobbering data on upstream errors.
          --list ZZ   Prints IP ranges or CIDR networks for both IPv4 and IPv6,
                      which are mapped to the country code ZZ, each on a separate
                      line. Order does not matter. Should use only local data,
                      but may use remote data (not recommended). Output examples:
                        127.0.0.0/8
                        192.0.0.0-192.0.0.255
                        fe80::/10
                        2001:db8::-2001:db8:ffff:ffff:ffff:ffff:ffff:ffff
        """
    # RawDescriptionHelpFormatter prints the epilog as written instead of re-wrapping it.
    parser = argparse.ArgumentParser(description="Manage ipsets for country filtering in the firewall",
                                     epilog=textwrap.dedent(epilog),
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    # Exactly one command must be chosen per invocation.
    commands = parser.add_mutually_exclusive_group(required=True)
    commands.add_argument('--configure', action='store_true',
                          help="Set up country ipsets. Create local GeoIP DB if missing, "
                               "persist settings, recreate country ipsets.")
    commands.add_argument('--update', action='store_true',
                          help="Update local GeoIP DB from a remote source, then recreate all "
                               "country ipsets. Use from a cron job.")
    commands.add_argument('--recreate', action='store_true',
                          help="Create missing and remove unused country ipsets. "
                               "Use from a firewall script.")
    parser.add_argument('-v', '--verbose', action='count', default=0,
                        help="Increase logging verbosity, can be specified multiple times.")
    parser.add_argument('-f', '--force', action='store_true',
                        help="Recreate all country ipsets instead of only missing and extra ones. "
                             "With --configure will also recreate local GeoIP DB and "
                             "update its settings.")
    # type= converts the bare name into the script path (or rejects it).
    parser.add_argument('-d', '--data-source', metavar='NAME', required=True, type=type_data_source,
                        help="Data source name. Each data source is a script under "
                             f"{DATA_SOURCE_BIN_D} or {DATA_SOURCE_VAR_D}, e.g. 'maxmind'.")
    # nargs='*': a bare -c yields [] (no countries), while omitting -c yields None
    # (callers then fall back to the persisted settings).
    parser.add_argument('-c', '--countries', nargs='*', metavar='ZZ', type=type_country_code,
                        help="List of 2-letter ISO 3166 country codes.")
    args = parser.parse_args()
    return args
def type_data_source(data_source):
    """ Type caster and checker for --data-source.

    Resolves a bare data source name (e.g. 'maxmind') to the path of an executable
    regular file. DATA_SOURCE_BIN_D is searched first, then DATA_SOURCE_VAR_D.
    Returns the resolved path.

    Raises argparse.ArgumentTypeError in either of these cases:
      - the name is not a plain file name (it is empty, '.' or '..', or contains a
        path separator);
      - neither directory contains an executable regular file with that name.
    """
    # Only plain names are accepted. Otherwise os.path.join() lets absolute paths
    # (which discard the directory) or '..' components escape the data source directories.
    is_plain_name = (data_source not in ('', os.curdir, os.pardir)
                     and os.path.basename(data_source) == data_source
                     and (os.altsep is None or os.altsep not in data_source))
    if is_plain_name:
        for data_source_d in (DATA_SOURCE_BIN_D, DATA_SOURCE_VAR_D):
            path = os.path.join(data_source_d, data_source)
            # os.access(X_OK) is also true for (searchable) directories, such as a
            # data source's '<name>.d' storage directory, so require a regular file.
            if os.path.isfile(path) and os.access(path, os.X_OK):
                return path
    raise argparse.ArgumentTypeError(f"Unsupported data source: {data_source!r}")
def type_country_code(code):
    """ Type caster and checker for --countries.

    Accepts exactly two ASCII uppercase letters (e.g. 'DE') and returns them unchanged.
    Raises argparse.ArgumentTypeError otherwise.
    """
    # str.isalpha()/isupper() would also accept non-ASCII letters (e.g. 'ÄÖ'), which are
    # never ISO 3166 codes. Such codes would end up in ipset names and data source arguments.
    if len(code) == 2 and all('A' <= char <= 'Z' for char in code):
        return code
    raise argparse.ArgumentTypeError(f"Not a 2-letter ISO 3166 country code: {code!r}")
def log_geoip_data_dir(data_source):
    """ Logs (at debug level) where the data source is expected to keep its local data.

    By convention this is DATA_SOURCE_VAR_D/<data source name>.d; nothing is checked.
    """
    name = os.path.basename(data_source)
    expected_dir = os.path.join(DATA_SOURCE_VAR_D, name + ".d")
    log.debug("Data directory for %r data source is expected to be %r", name, expected_dir)
def has_geoip_data(data_source):
    """ Asks the data source script whether GeoIP data was already fetched.

    Returns True only if `<data_source> --exists` exits with status 0.
    """
    log.debug("Checking for GeoIP data existence via %r", data_source)
    exit_status = subprocess.call([data_source, '--exists'])
    return exit_status == 0
def fetch_geoip_data(data_source):
    """ Runs `<data_source> --fetch` to (re)fetch GeoIP data from its remote source.

    Raises subprocess.CalledProcessError if the script exits with a non-zero status.
    """
    command = [data_source, '--fetch']
    log.info("Fetching GeoIP data via %r", data_source)
    subprocess.check_call(command)
def list_geoip_data(data_source, country_code):
    """ Returns the GeoIP entries of country_code, as printed by `<data_source> --list`.

    Entries are IPv4/IPv6 ranges or CIDR networks (assuming the data was fetched);
    the output is split on any whitespace.
    Raises subprocess.CalledProcessError if the script exits with a non-zero status.
    """
    log.debug("Listing GeoIP data for %r via %r", country_code, data_source)
    command = [data_source, '--list', country_code]
    output = subprocess.check_output(command, universal_newlines=True)
    entries = output.split()
    return entries
def geoip_data_to_networks(entries):
    """ Yields IPv4Network/IPv6Network objects for a list of ranges or networks.

    An entry containing '-' is a 'first-last' address range and is summarized into
    the minimal list of covering networks; anything else is parsed as a network or
    a single address. Invalid entries raise ValueError when they are reached.

    >>> list(geoip_data_to_networks(['10.0.0.0/24', 'fe80::/10']))
    [IPv4Network('10.0.0.0/24'), IPv6Network('fe80::/10')]
    >>> list(geoip_data_to_networks(['10.0.0.0-10.0.0.19', '::-::3']))
    [IPv4Network('10.0.0.0/28'), IPv4Network('10.0.0.16/30'), IPv6Network('::/126')]
    >>> list(geoip_data_to_networks(['127.0.0.1', '::1']))
    [IPv4Network('127.0.0.1/32'), IPv6Network('::1/128')]
    >>> list(geoip_data_to_networks(['invalid']))
    Traceback (most recent call last):
    ...
    ValueError: 'invalid' does not appear to be an IPv4 or IPv6 network
    >>> list(geoip_data_to_networks(['from-to']))
    Traceback (most recent call last):
    ...
    ValueError: 'from' does not appear to be an IPv4 or IPv6 address
    """
    for entry in entries:
        first, dash, last = entry.partition('-')
        if not dash:
            yield ipaddress.ip_network(entry)
            continue
        yield from ipaddress.summarize_address_range(ipaddress.ip_address(first),
                                                     ipaddress.ip_address(last))
def list_existing_ipset_names():
    """ Returns the names of all ipsets currently defined on the system (`ipset list -name`). """
    log.debug("Listing existing ipset names from system")
    output = subprocess.check_output(["ipset", "list", "-name"], universal_newlines=True)
    names = output.split()
    log.debug("Got ipset names: %r", names)
    return names
def round_to_power_of_2(x):
    """ Returns the smallest power of 2 that is >= x; values below 1 give 1 (2 ** 0).
    >>> round_to_power_of_2(0)
    1
    >>> round_to_power_of_2(1)
    1
    >>> round_to_power_of_2(32)
    32
    >>> round_to_power_of_2(1000)
    1024
    """
    if x >= 1:
        # (x - 1).bit_length() is the exponent of the next power of 2 not below x.
        return 1 << (x - 1).bit_length()
    return 1
def create_ipset(ipset_name, ip_version, num_elements=0):
    """ Ensures a hash:net ipset named ipset_name exists for ip_version (4 or 6).

    num_elements is the expected number of entries. It is padded by 50% and rounded
    up to a power of 2, so the resulting 'maxelem' value changes rarely between data
    updates. 'maxelem' is only passed when this exceeds 65536 (presumably ipset's
    default). The set is created with '-exist'. If creation still fails (e.g. the
    existing set was created with another 'maxelem'), the set is destroyed and
    created again.

    Raises RuntimeError if the old set cannot be destroyed (e.g. it is still
    referenced by firewall rules). Raises subprocess.CalledProcessError if the
    second creation attempt fails.
    """
    capacity = round_to_power_of_2(int(num_elements * 1.5))
    family = "inet6" if str(ip_version) == '6' else "inet"
    cmd = ["ipset", "create", ipset_name, "hash:net", "-exist", "family", family]
    if capacity > 65536:
        cmd.extend(["maxelem", str(capacity)])
    try:
        log.debug("Creating %r ipset: %r", ipset_name, cmd)
        subprocess.check_call(cmd)
    except Exception as create_error:
        log.warning("Failed to create %r ipset from the first try, possibly 'maxelem' changed, "
                    "will try recreating: %s",
                    ipset_name, create_error)
        try:
            destroy_ipset(ipset_name)
        except Exception as destroy_error:
            log.debug("Destroying %r ipset failed, likely due to existing references", ipset_name)
            raise RuntimeError(f"Cannot recreate ipset {ipset_name!r}: {destroy_error} "
                               "Try stopping the plesk-firewall.service first.") from destroy_error
        # Still inside the handler, so a failure here is chained to the first error.
        log.debug("Creating new %r ipset: %r", ipset_name, cmd)
        subprocess.check_call(cmd)
def destroy_ipset(ipset_name):
    """ Destroys the ipset ipset_name via `ipset destroy`.

    Fails (subprocess.CalledProcessError) e.g. while iptables rules still reference the set.
    """
    command = ["ipset", "destroy", ipset_name]
    log.debug("Destroying %r ipset", ipset_name)
    subprocess.check_call(command)
def update_ipset(ipset_name, networks):
    """ Replaces the contents of ipset_name with networks.

    The set is flushed and refilled in a single `ipset restore` run. networks may be
    any iterable of IPv4Network/IPv6Network objects (or their string forms).
    Duplicates are dropped first, keeping the original order. `ipset restore` aborts
    on an 'add' of an element already in the set, so a repeated entry from a data
    source would otherwise fail the whole update.

    Raises subprocess.CalledProcessError if `ipset restore` fails.
    """
    unique_networks = list(dict.fromkeys(networks))
    commands = [f"flush {ipset_name}"] + [f"add {ipset_name} {net}" for net in unique_networks]
    log.debug("Updating %r ipset networks, %d entries", ipset_name, len(unique_networks))
    subprocess.run(["ipset", "restore"], check=True, universal_newlines=True,
                   input="\n".join(commands))
def ipset_name(country_code, ip_version):
    """ Builds the ipset name for country_code and ip_version (4 or 6).

    The name has the form IPSET_PREFIX + <ip_version> + '-' + <country_code>, e.g. 'plesk-ip4-DE'.
    """
    return "-".join((IPSET_PREFIX + str(ip_version), country_code))
def _split_networks_by_version(networks):
    """ Splits an iterable of ip networks into (IPv4 list, IPv6 list), keeping order. """
    v4_nets, v6_nets = [], []
    for net in networks:
        if net.version == 4:
            v4_nets.append(net)
        elif net.version == 6:
            v6_nets.append(net)
        else:
            raise RuntimeError(f"Network {net} is neither IPv4 nor IPv6")
    return v4_nets, v6_nets


def recreate_ipsets(data_source, countries, recreate_all=False):
    """ Recreates ipsets for the countries, using data_source.

    Each country gets two ipsets, an IPv4 one and an IPv6 one, named by ipset_name().
    By default, a country's ipsets are created and populated only if at least one of
    them is missing. If recreate_all, they are created and repopulated for every country.
    Afterwards, existing country ipsets that belong to none of the countries are
    destroyed. This is best-effort: failures are logged and retried on the next call.

    Errors from listing, parsing, creating or populating the ipsets propagate.
    """
    existing_ipsets = set(list_existing_ipset_names())
    log.debug("Checking for missing ipsets (recreate_all=%r)", recreate_all)
    required_ipsets = set()
    # dict.fromkeys() drops repeated country codes while keeping their order, so the
    # same country is not fetched and populated twice.
    for country_code in dict.fromkeys(countries):
        v4_name, v6_name = ipset_name(country_code, 4), ipset_name(country_code, 6)
        required_ipsets.update((v4_name, v6_name))
        if not recreate_all and v4_name in existing_ipsets and v6_name in existing_ipsets:
            log.debug("Skip recreating already existing ipsets for %r country: %r, %r",
                      country_code, v4_name, v6_name)
            continue
        log.info("Creating and populating ipsets for %r country: %r, %r",
                 country_code, v4_name, v6_name)
        v4_nets, v6_nets = _split_networks_by_version(
            geoip_data_to_networks(list_geoip_data(data_source, country_code)))
        create_ipset(v4_name, 4, len(v4_nets))
        create_ipset(v6_name, 6, len(v6_nets))
        update_ipset(v4_name, v4_nets)
        update_ipset(v6_name, v6_nets)
    log.debug("Checking for unused ipsets")
    # Only names following the ipset_name() scheme are ours. A bare IPSET_PREFIX match
    # would also destroy any other set whose name merely starts with that prefix.
    country_prefixes = (ipset_name("", 4), ipset_name("", 6))
    for name in sorted(existing_ipsets - required_ipsets):
        if not name.startswith(country_prefixes):
            continue
        try:
            log.info("Destroying unused ipset: %r", name)
            destroy_ipset(name)
        except Exception as ex:
            log.warning("Cannot remove ipset %r, will try next time: %s", name, ex)
def store_settings(countries):
    """ Stores settings for subsequent calls.

    Writes {"countries": [sorted country codes]} as JSON plus a trailing newline to
    SETTINGS_PATH, creating its directory if needed. The data is written to a
    temporary file in the same directory first, synced, and then renamed over
    SETTINGS_PATH. An interrupted write therefore cannot leave a truncated settings
    file behind for fetch_settings() to choke on.
    """
    log.debug("Storing settings into %r", SETTINGS_PATH)
    data = {
        'countries': sorted(countries),
    }
    os.makedirs(os.path.dirname(SETTINGS_PATH), 0o755, exist_ok=True)
    # Per-process temp name, so concurrent writers do not clobber each other's file.
    tmp_path = f"{SETTINGS_PATH}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'w') as fd:
            json.dump(data, fd)
            fd.write("\n")
            fd.flush()
            os.fsync(fd.fileno())
        os.replace(tmp_path, SETTINGS_PATH)
    except BaseException:
        # Do not leave a stray temp file behind; the original error is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def fetch_settings():
    """ Loads the settings persisted in SETTINGS_PATH and returns the stored country list.

    Any failure (missing or unreadable file, invalid JSON, missing 'countries' key)
    is re-raised as RuntimeError chained to the original error.
    """
    log.debug("Fetching settings from %r", SETTINGS_PATH)
    try:
        with open(SETTINGS_PATH, 'r') as settings_file:
            settings = json.load(settings_file)
        log.debug("Fetched settings: %r", settings)
        countries = settings['countries']
    except Exception as err:
        raise RuntimeError(f"Cannot read persisted settings from {SETTINGS_PATH!r}: {err}") from err
    return countries
def configure(data_source, countries, recreate_all=False):
    """ Sets up country ipsets from data_source and persists the country list.

    GeoIP data is fetched when recreate_all is set or when the data source reports
    that no local data exists yet. countries=None is treated as "no countries".
    """
    must_fetch = recreate_all or not has_geoip_data(data_source)
    if must_fetch:
        fetch_geoip_data(data_source)
    selected = countries or []
    store_settings(selected)
    recreate_ipsets(data_source, selected, recreate_all)
def update(data_source, countries):
    """ Refetches GeoIP data via data_source, then rebuilds all country ipsets.

    countries=None means the country list persisted by the last --configure is used.
    """
    fetch_geoip_data(data_source)
    selected = fetch_settings() if countries is None else countries
    recreate_ipsets(data_source, selected, recreate_all=True)
def recreate(data_source, countries, recreate_all=False):
    """ Creates missing and removes unused country ipsets using the local data of data_source.

    No GeoIP data is fetched. countries=None means the country list persisted by
    the last --configure is used.
    """
    selected = fetch_settings() if countries is None else countries
    recreate_ipsets(data_source, selected, recreate_all)
def main():
    """ Entry point: parses the command line, sets up logging and runs the chosen command. """
    options = parse_args()
    set_up_logging(options.verbose)
    log.debug("Options: %s", options)
    log_geoip_data_dir(options.data_source)
    data_source, countries = options.data_source, options.countries
    if options.configure:
        configure(data_source, countries, options.force)
    elif options.update:
        update(data_source, countries)
    elif options.recreate:
        recreate(data_source, countries, options.force)
if __name__ == '__main__':
    try:
        main()
    except Exception as err:
        # Always report the error on stderr, even when logging is at CRITICAL only.
        print(str(err), file=sys.stderr)
        log.error("%s", err)
        # exc_info=True makes logging capture the exception currently being handled.
        log.debug("This exception happened at:", exc_info=True)
        sys.exit(1)
# vim: ft=python