Merge pull request #238 from naturallaw777/copilot/fix-nft-regex-firewall-allowed-ports
[WIP] Fix nft regex to correctly capture allowed ports
This commit is contained in:
@@ -7,6 +7,7 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import pwd
|
import pwd
|
||||||
import re
|
import re
|
||||||
@@ -31,6 +32,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
from .config import load_config
|
from .config import load_config
|
||||||
from . import systemctl as sysctl
|
from . import systemctl as sysctl
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ── Constants ──────────────────────────────────────────────────────
|
# ── Constants ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
FLAKE_LOCK_PATH = "/etc/nixos/flake.lock"
|
FLAKE_LOCK_PATH = "/etc/nixos/flake.lock"
|
||||||
@@ -749,7 +752,7 @@ def _get_external_ip() -> str:
|
|||||||
def _get_listening_ports() -> dict[str, set[int]]:
|
def _get_listening_ports() -> dict[str, set[int]]:
|
||||||
"""Return sets of TCP and UDP ports that have services actively listening.
|
"""Return sets of TCP and UDP ports that have services actively listening.
|
||||||
|
|
||||||
Uses ``ss -tlnp`` for TCP and ``ss -ulnp`` for UDP. Returns a dict with
|
Uses ``ss -tln`` for TCP and ``ss -uln`` for UDP. Returns a dict with
|
||||||
keys ``"tcp"`` and ``"udp"`` whose values are sets of integer port numbers.
|
keys ``"tcp"`` and ``"udp"`` whose values are sets of integer port numbers.
|
||||||
|
|
||||||
The ``ss`` LISTEN/UNCONN output has a fixed column layout when split on
|
The ``ss`` LISTEN/UNCONN output has a fixed column layout when split on
|
||||||
@@ -761,12 +764,23 @@ def _get_listening_ports() -> dict[str, set[int]]:
|
|||||||
so only truly active listeners are returned.
|
so only truly active listeners are returned.
|
||||||
"""
|
"""
|
||||||
result: dict[str, set[int]] = {"tcp": set(), "udp": set()}
|
result: dict[str, set[int]] = {"tcp": set(), "udp": set()}
|
||||||
for proto, flag in (("tcp", "-tlnp"), ("udp", "-ulnp")):
|
|
||||||
|
def _extract_port(addr: str) -> int | None:
|
||||||
|
m = re.search(r":(\d+)$", addr)
|
||||||
|
if not m:
|
||||||
|
return None
|
||||||
|
return int(m.group(1))
|
||||||
|
|
||||||
|
for proto, flag in (("tcp", "-tln"), ("udp", "-uln")):
|
||||||
try:
|
try:
|
||||||
proc = subprocess.run(
|
proc = subprocess.run(
|
||||||
["ss", flag],
|
["ss", flag],
|
||||||
capture_output=True, text=True, timeout=10,
|
capture_output=True, text=True, timeout=10,
|
||||||
)
|
)
|
||||||
|
logger.debug("ss %s rc=%s stderr=%r", flag, proc.returncode, proc.stderr.strip())
|
||||||
|
logger.debug("ss %s output sample: %r", flag, "\n".join(proc.stdout.splitlines()[:8]))
|
||||||
|
if proc.returncode != 0:
|
||||||
|
continue
|
||||||
for line in proc.stdout.splitlines():
|
for line in proc.stdout.splitlines():
|
||||||
parts = line.split()
|
parts = line.split()
|
||||||
if len(parts) < 5:
|
if len(parts) < 5:
|
||||||
@@ -777,20 +791,26 @@ def _get_listening_ports() -> dict[str, set[int]]:
|
|||||||
# Only process LISTEN (TCP) or UNCONN (UDP) state lines
|
# Only process LISTEN (TCP) or UNCONN (UDP) state lines
|
||||||
if parts[0] not in ("LISTEN", "UNCONN"):
|
if parts[0] not in ("LISTEN", "UNCONN"):
|
||||||
continue
|
continue
|
||||||
# Local address is always at column index 3:
|
# Typical layout:
|
||||||
# State Recv-Q Send-Q Local_Address:Port Peer_Address:Port ...
|
# State Recv-Q Send-Q Local_Address:Port Peer_Address:Port ...
|
||||||
# Formats: 0.0.0.0:443, *:443, [::]:443, 127.0.0.1:443
|
# Be defensive and fall back to scanning for the first token that
|
||||||
|
# looks like an address with a numeric port.
|
||||||
local_addr = parts[3]
|
local_addr = parts[3]
|
||||||
port_str = local_addr.rsplit(":", 1)[-1]
|
port = _extract_port(local_addr)
|
||||||
# Defensively skip wildcard port (e.g. an unbound socket showing *:*)
|
if port is None:
|
||||||
if port_str == "*":
|
for token in parts[3:]:
|
||||||
continue
|
port = _extract_port(token)
|
||||||
try:
|
if port is not None:
|
||||||
result[proto].add(int(port_str))
|
break
|
||||||
except ValueError:
|
if port is not None:
|
||||||
pass
|
result[proto].add(port)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
logger.debug(
|
||||||
|
"parsed listening ports: tcp=%s udp=%s",
|
||||||
|
sorted(result["tcp"]),
|
||||||
|
sorted(result["udp"]),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -808,13 +828,13 @@ def _get_firewall_allowed_ports() -> dict[str, set[int]]:
|
|||||||
["nft", "list", "ruleset"],
|
["nft", "list", "ruleset"],
|
||||||
capture_output=True, text=True, timeout=10,
|
capture_output=True, text=True, timeout=10,
|
||||||
)
|
)
|
||||||
|
logger.debug("nft list ruleset rc=%s stderr=%r", proc.returncode, proc.stderr.strip())
|
||||||
|
logger.debug("nft output sample: %r", "\n".join(proc.stdout.splitlines()[:12]))
|
||||||
if proc.returncode == 0:
|
if proc.returncode == 0:
|
||||||
text = proc.stdout
|
text = proc.stdout
|
||||||
# Match patterns like: tcp dport 443 accept or tcp dport { 80, 443 }
|
# Match patterns like: tcp dport 443 ... or tcp dport { 80, 443 } ...
|
||||||
for proto in ("tcp", "udp"):
|
for proto in ("tcp", "udp"):
|
||||||
for m in re.finditer(
|
for m in re.finditer(rf"{proto}\s+dport\s+\{{\s*([^}}]+?)\s*\}}", text):
|
||||||
rf'{proto}\s+dport\s+\{{?([^}};\n]+)\}}?', text
|
|
||||||
):
|
|
||||||
raw = m.group(1)
|
raw = m.group(1)
|
||||||
for token in re.split(r'[\s,]+', raw):
|
for token in re.split(r'[\s,]+', raw):
|
||||||
token = token.strip()
|
token = token.strip()
|
||||||
@@ -823,6 +843,18 @@ def _get_firewall_allowed_ports() -> dict[str, set[int]]:
|
|||||||
elif re.match(r'^(\d+)-(\d+)$', token):
|
elif re.match(r'^(\d+)-(\d+)$', token):
|
||||||
lo, hi = token.split("-")
|
lo, hi = token.split("-")
|
||||||
result[proto].update(range(int(lo), int(hi) + 1))
|
result[proto].update(range(int(lo), int(hi) + 1))
|
||||||
|
for m in re.finditer(rf"{proto}\s+dport\s+(\d+(?:-\d+)?)\b", text):
|
||||||
|
token = m.group(1)
|
||||||
|
if re.match(r'^\d+$', token):
|
||||||
|
result[proto].add(int(token))
|
||||||
|
else:
|
||||||
|
lo, hi = token.split("-")
|
||||||
|
result[proto].update(range(int(lo), int(hi) + 1))
|
||||||
|
logger.debug(
|
||||||
|
"parsed firewall ports from nft: tcp=%s udp=%s",
|
||||||
|
sorted(result["tcp"]),
|
||||||
|
sorted(result["udp"]),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -881,9 +913,9 @@ def _check_port_status(
|
|||||||
|
|
||||||
ports_set = set(ports)
|
ports_set = set(ports)
|
||||||
is_listening = any(
|
is_listening = any(
|
||||||
pt in ports_set
|
pt in listening.get(proto_key, set())
|
||||||
for proto_key in protos
|
for proto_key in protos
|
||||||
for pt in listening.get(proto_key, set())
|
for pt in ports_set
|
||||||
)
|
)
|
||||||
is_allowed = any(
|
is_allowed = any(
|
||||||
pt in allowed.get(proto_key, set())
|
pt in allowed.get(proto_key, set())
|
||||||
|
|||||||
Reference in New Issue
Block a user