Merge pull request #238 from naturallaw777/copilot/fix-nft-regex-firewall-allowed-ports

[WIP] Fix nft regex to correctly capture allowed ports
This commit is contained in:
Sovran_Systems
2026-04-14 16:45:01 -05:00
committed by GitHub

View File

@@ -7,6 +7,7 @@ import base64
import hashlib import hashlib
import hmac import hmac
import json import json
import logging
import os import os
import pwd import pwd
import re import re
@@ -31,6 +32,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
from .config import load_config from .config import load_config
from . import systemctl as sysctl from . import systemctl as sysctl
logger = logging.getLogger(__name__)
# ── Constants ────────────────────────────────────────────────────── # ── Constants ──────────────────────────────────────────────────────
FLAKE_LOCK_PATH = "/etc/nixos/flake.lock" FLAKE_LOCK_PATH = "/etc/nixos/flake.lock"
@@ -749,7 +752,7 @@ def _get_external_ip() -> str:
def _get_listening_ports() -> dict[str, set[int]]: def _get_listening_ports() -> dict[str, set[int]]:
"""Return sets of TCP and UDP ports that have services actively listening. """Return sets of TCP and UDP ports that have services actively listening.
Uses ``ss -tlnp`` for TCP and ``ss -ulnp`` for UDP. Returns a dict with Uses ``ss -tln`` for TCP and ``ss -uln`` for UDP. Returns a dict with
keys ``"tcp"`` and ``"udp"`` whose values are sets of integer port numbers. keys ``"tcp"`` and ``"udp"`` whose values are sets of integer port numbers.
The ``ss`` LISTEN/UNCONN output has a fixed column layout when split on The ``ss`` LISTEN/UNCONN output has a fixed column layout when split on
@@ -761,12 +764,23 @@ def _get_listening_ports() -> dict[str, set[int]]:
so only truly active listeners are returned. so only truly active listeners are returned.
""" """
result: dict[str, set[int]] = {"tcp": set(), "udp": set()} result: dict[str, set[int]] = {"tcp": set(), "udp": set()}
for proto, flag in (("tcp", "-tlnp"), ("udp", "-ulnp")):
def _extract_port(addr: str) -> int | None:
m = re.search(r":(\d+)$", addr)
if not m:
return None
return int(m.group(1))
for proto, flag in (("tcp", "-tln"), ("udp", "-uln")):
try: try:
proc = subprocess.run( proc = subprocess.run(
["ss", flag], ["ss", flag],
capture_output=True, text=True, timeout=10, capture_output=True, text=True, timeout=10,
) )
logger.debug("ss %s rc=%s stderr=%r", flag, proc.returncode, proc.stderr.strip())
logger.debug("ss %s output sample: %r", flag, "\n".join(proc.stdout.splitlines()[:8]))
if proc.returncode != 0:
continue
for line in proc.stdout.splitlines(): for line in proc.stdout.splitlines():
parts = line.split() parts = line.split()
if len(parts) < 5: if len(parts) < 5:
@@ -777,20 +791,26 @@ def _get_listening_ports() -> dict[str, set[int]]:
# Only process LISTEN (TCP) or UNCONN (UDP) state lines # Only process LISTEN (TCP) or UNCONN (UDP) state lines
if parts[0] not in ("LISTEN", "UNCONN"): if parts[0] not in ("LISTEN", "UNCONN"):
continue continue
# Local address is always at column index 3: # Typical layout:
# State Recv-Q Send-Q Local_Address:Port Peer_Address:Port ... # State Recv-Q Send-Q Local_Address:Port Peer_Address:Port ...
# Formats: 0.0.0.0:443, *:443, [::]:443, 127.0.0.1:443 # Be defensive and fall back to scanning for the first token that
# looks like an address with a numeric port.
local_addr = parts[3] local_addr = parts[3]
port_str = local_addr.rsplit(":", 1)[-1] port = _extract_port(local_addr)
# Defensively skip wildcard port (e.g. an unbound socket showing *:*) if port is None:
if port_str == "*": for token in parts[3:]:
continue port = _extract_port(token)
try: if port is not None:
result[proto].add(int(port_str)) break
except ValueError: if port is not None:
pass result[proto].add(port)
except Exception: except Exception:
pass pass
logger.debug(
"parsed listening ports: tcp=%s udp=%s",
sorted(result["tcp"]),
sorted(result["udp"]),
)
return result return result
@@ -808,13 +828,13 @@ def _get_firewall_allowed_ports() -> dict[str, set[int]]:
["nft", "list", "ruleset"], ["nft", "list", "ruleset"],
capture_output=True, text=True, timeout=10, capture_output=True, text=True, timeout=10,
) )
logger.debug("nft list ruleset rc=%s stderr=%r", proc.returncode, proc.stderr.strip())
logger.debug("nft output sample: %r", "\n".join(proc.stdout.splitlines()[:12]))
if proc.returncode == 0: if proc.returncode == 0:
text = proc.stdout text = proc.stdout
# Match patterns like: tcp dport 443 accept or tcp dport { 80, 443 } # Match patterns like: tcp dport 443 ... or tcp dport { 80, 443 } ...
for proto in ("tcp", "udp"): for proto in ("tcp", "udp"):
for m in re.finditer( for m in re.finditer(rf"{proto}\s+dport\s+\{{\s*([^}}]+?)\s*\}}", text):
rf'{proto}\s+dport\s+\{{?([^}};\n]+)\}}?', text
):
raw = m.group(1) raw = m.group(1)
for token in re.split(r'[\s,]+', raw): for token in re.split(r'[\s,]+', raw):
token = token.strip() token = token.strip()
@@ -823,6 +843,18 @@ def _get_firewall_allowed_ports() -> dict[str, set[int]]:
elif re.match(r'^(\d+)-(\d+)$', token): elif re.match(r'^(\d+)-(\d+)$', token):
lo, hi = token.split("-") lo, hi = token.split("-")
result[proto].update(range(int(lo), int(hi) + 1)) result[proto].update(range(int(lo), int(hi) + 1))
for m in re.finditer(rf"{proto}\s+dport\s+(\d+(?:-\d+)?)\b", text):
token = m.group(1)
if re.match(r'^\d+$', token):
result[proto].add(int(token))
else:
lo, hi = token.split("-")
result[proto].update(range(int(lo), int(hi) + 1))
logger.debug(
"parsed firewall ports from nft: tcp=%s udp=%s",
sorted(result["tcp"]),
sorted(result["udp"]),
)
return result return result
except Exception: except Exception:
pass pass
@@ -881,9 +913,9 @@ def _check_port_status(
ports_set = set(ports) ports_set = set(ports)
is_listening = any( is_listening = any(
pt in ports_set pt in listening.get(proto_key, set())
for proto_key in protos for proto_key in protos
for pt in listening.get(proto_key, set()) for pt in ports_set
) )
is_allowed = any( is_allowed = any(
pt in allowed.get(proto_key, set()) pt in allowed.get(proto_key, set())