commit
This commit is contained in:
@@ -0,0 +1,158 @@
|
||||
"""Helpers for streaming journalctl logs over SSH."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import shlex
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
import paramiko
|
||||
|
||||
SSH_KEY_PATH = Path(os.getenv("SSH_KEY_PATH", "keys/5G-SSH-Key.pem"))
|
||||
JQ_FILTER = '.TIMESTAMP + " " + .SYSLOG_IDENTIFIER + " " + (.SUPI // "") + " " + .MESSAGE'
|
||||
|
||||
@dataclass
|
||||
class LogTarget:
|
||||
host: str
|
||||
processes: List[str]
|
||||
hostname: str | None = None
|
||||
|
||||
class JournalctlStream:
|
||||
"""Manage concurrent journalctl streams for multiple hosts."""
|
||||
|
||||
def __init__(self, targets: Iterable[LogTarget]) -> None:
|
||||
self._logger = logging.getLogger(__name__)
|
||||
self.targets = [t for t in targets if t.processes]
|
||||
if not self.targets:
|
||||
raise ValueError("No valid log targets provided")
|
||||
if not SSH_KEY_PATH.exists():
|
||||
raise FileNotFoundError(f"SSH key not found at {SSH_KEY_PATH}")
|
||||
self._queue: "queue.Queue[dict]" = queue.Queue()
|
||||
self._stop_event = threading.Event()
|
||||
self._threads: list[threading.Thread] = []
|
||||
self._clients: dict[str, paramiko.SSHClient] = {}
|
||||
|
||||
def start(self) -> None:
|
||||
for target in self.targets:
|
||||
thread = threading.Thread(target=self._stream_host, args=(target,), daemon=True)
|
||||
thread.start()
|
||||
self._threads.append(thread)
|
||||
|
||||
def iter_events(self):
|
||||
self.start()
|
||||
finished = 0
|
||||
total = len(self.targets)
|
||||
while finished < total and not self._stop_event.is_set():
|
||||
try:
|
||||
event = self._queue.get(timeout=0.5)
|
||||
except queue.Empty:
|
||||
yield {"type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
continue
|
||||
if event.get("type") == "complete":
|
||||
finished += 1
|
||||
yield event
|
||||
# Drain remaining events if any
|
||||
while not self._queue.empty():
|
||||
yield self._queue.get()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop_event.set()
|
||||
for client in self._clients.values():
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
pass
|
||||
for thread in self._threads:
|
||||
thread.join(timeout=1)
|
||||
|
||||
# --- internal helpers ---
|
||||
|
||||
def _stream_host(self, target: LogTarget) -> None:
|
||||
filter_args = []
|
||||
for proc in target.processes:
|
||||
proc = (proc or "").strip()
|
||||
if not proc:
|
||||
continue
|
||||
filter_args.append(f"-t {shlex.quote(proc)}")
|
||||
|
||||
safe_filters = " ".join(filter_args)
|
||||
if not safe_filters:
|
||||
message = "No processes selected"
|
||||
self._logger.error("Log stream aborted for %s: %s", target.host, message)
|
||||
self._queue.put({
|
||||
"type": "error",
|
||||
"host": target.host,
|
||||
"hostname": target.hostname or target.host,
|
||||
"message": message
|
||||
})
|
||||
self._queue.put({"type": "complete", "host": target.host, "hostname": target.hostname or target.host})
|
||||
return
|
||||
|
||||
pipeline = (
|
||||
f"journalctl {safe_filters} -o json -n 50 -f "
|
||||
f"| jq -r --unbuffered '{JQ_FILTER}'"
|
||||
)
|
||||
command = f"bash -lc {shlex.quote(pipeline)}"
|
||||
self._logger.debug("Executing log command on %s: %s", target.host, pipeline)
|
||||
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
try:
|
||||
client.connect(
|
||||
target.host,
|
||||
username="root",
|
||||
key_filename=str(SSH_KEY_PATH),
|
||||
look_for_keys=False,
|
||||
timeout=15,
|
||||
)
|
||||
self._clients[target.host] = client
|
||||
_, stdout, stderr = client.exec_command(command, get_pty=True)
|
||||
for line in iter(lambda: stdout.readline(), ""):
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
payload = line.rstrip("\r\n")
|
||||
if not payload:
|
||||
continue
|
||||
self._queue.put({
|
||||
"type": "log",
|
||||
"host": target.host,
|
||||
"hostname": target.hostname or target.host,
|
||||
"line": payload,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
err = stderr.read().decode(errors="ignore").strip()
|
||||
exit_status = stdout.channel.recv_exit_status()
|
||||
if exit_status not in (0, -1) or err:
|
||||
message = err or f"journalctl exited with status {exit_status}"
|
||||
self._logger.error("Log stream error on %s: %s", target.host, message)
|
||||
self._queue.put({
|
||||
"type": "error",
|
||||
"host": target.host,
|
||||
"hostname": target.hostname or target.host,
|
||||
"message": message,
|
||||
})
|
||||
except Exception as exc:
|
||||
message = str(exc)
|
||||
self._logger.exception("Log stream exception for %s: %s", target.host, message)
|
||||
self._queue.put({
|
||||
"type": "error",
|
||||
"host": target.host,
|
||||
"hostname": target.hostname or target.host,
|
||||
"message": message,
|
||||
})
|
||||
finally:
|
||||
self._queue.put({
|
||||
"type": "complete",
|
||||
"host": target.host,
|
||||
"hostname": target.hostname or target.host,
|
||||
})
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -0,0 +1,194 @@
|
||||
"""Helpers for managing OpenVPN processes inside the container."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
VPN_CONFIG_DIR = Path(os.getenv("VPN_CONFIG_DIR", "/vpn/configs"))
|
||||
VPN_RUNTIME_DIR = Path(os.getenv("VPN_RUNTIME_DIR", "/vpn/runtime"))
|
||||
PID_FILE = VPN_RUNTIME_DIR / "openvpn.pid"
|
||||
STATE_FILE = VPN_RUNTIME_DIR / "active_vpn"
|
||||
LOG_FILE = VPN_RUNTIME_DIR / "openvpn.log"
|
||||
START_TIMEOUT = float(os.getenv("VPN_START_TIMEOUT", "15"))
|
||||
STOP_TIMEOUT = float(os.getenv("VPN_STOP_TIMEOUT", "10"))
|
||||
AUTH_DIRECTIVE = "auth-user-pass"
|
||||
|
||||
class VPNRuntimeError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_runtime_dir() -> None:
|
||||
VPN_RUNTIME_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _config_path(vpn_name: str) -> Path:
|
||||
VPN_CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
for suffix in (".conf", ".ovpn"):
|
||||
candidate = VPN_CONFIG_DIR / f"{vpn_name}{suffix}"
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
raise VPNRuntimeError(f"No config found for VPN '{vpn_name}' in {VPN_CONFIG_DIR}")
|
||||
|
||||
|
||||
def _prepare_auth_override(vpn_name: str, config_path: Path) -> list[str]:
|
||||
"""Return CLI args to supply a sanitized auth-user-pass file if needed."""
|
||||
try:
|
||||
lines = config_path.read_text().splitlines()
|
||||
except FileNotFoundError as exc:
|
||||
raise VPNRuntimeError(f"Config {config_path} missing: {exc}") from exc
|
||||
|
||||
auth_target: Path | None = None
|
||||
for raw in lines:
|
||||
line = raw.strip()
|
||||
if not line or line.startswith("#") or line.startswith(";"):
|
||||
continue
|
||||
if not line.lower().startswith(AUTH_DIRECTIVE):
|
||||
continue
|
||||
parts = line.split()
|
||||
if len(parts) < 2:
|
||||
continue
|
||||
auth_path = parts[1].strip('"')
|
||||
candidate = Path(auth_path)
|
||||
if not candidate.is_absolute():
|
||||
candidate = config_path.parent / candidate
|
||||
auth_target = candidate
|
||||
break
|
||||
|
||||
if not auth_target:
|
||||
return []
|
||||
if not auth_target.exists():
|
||||
raise VPNRuntimeError(
|
||||
f"Auth file {auth_target} referenced by {config_path} does not exist"
|
||||
)
|
||||
|
||||
dest = VPN_RUNTIME_DIR / f"{vpn_name}.auth"
|
||||
try:
|
||||
data = auth_target.read_bytes()
|
||||
except OSError as exc:
|
||||
raise VPNRuntimeError(f"Failed to read auth file {auth_target}: {exc}") from exc
|
||||
if not data.strip():
|
||||
raise VPNRuntimeError(f"Auth file {auth_target} is empty")
|
||||
|
||||
try:
|
||||
dest.write_bytes(data)
|
||||
dest.chmod(0o600)
|
||||
except OSError as exc:
|
||||
raise VPNRuntimeError(f"Failed to stage auth file at {dest}: {exc}") from exc
|
||||
|
||||
return ["--auth-user-pass", str(dest)]
|
||||
|
||||
|
||||
def _is_pid_running(pid: int) -> bool:
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def _read_pid() -> int | None:
|
||||
if PID_FILE.exists():
|
||||
try:
|
||||
return int(PID_FILE.read_text().strip())
|
||||
except ValueError:
|
||||
PID_FILE.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
|
||||
def _write_state(vpn_name: str) -> None:
|
||||
STATE_FILE.write_text(vpn_name)
|
||||
|
||||
|
||||
def _clear_state() -> None:
|
||||
PID_FILE.unlink(missing_ok=True)
|
||||
STATE_FILE.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def get_active_vpn() -> str | None:
|
||||
pid = _read_pid()
|
||||
if not pid:
|
||||
_clear_state()
|
||||
return None
|
||||
if not _is_pid_running(pid):
|
||||
_clear_state()
|
||||
return None
|
||||
if STATE_FILE.exists():
|
||||
return STATE_FILE.read_text().strip() or None
|
||||
return None
|
||||
|
||||
|
||||
def stop_active_vpn() -> None:
|
||||
pid = _read_pid()
|
||||
if not pid:
|
||||
_clear_state()
|
||||
return
|
||||
if not _is_pid_running(pid):
|
||||
_clear_state()
|
||||
return
|
||||
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
deadline = time.time() + STOP_TIMEOUT
|
||||
while time.time() < deadline:
|
||||
if not _is_pid_running(pid):
|
||||
_clear_state()
|
||||
return
|
||||
time.sleep(0.5)
|
||||
|
||||
# escalate
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
_clear_state()
|
||||
|
||||
|
||||
def start_vpn(vpn_name: str) -> None:
|
||||
config_path = _config_path(vpn_name)
|
||||
_ensure_runtime_dir()
|
||||
extra_args = _prepare_auth_override(vpn_name, config_path)
|
||||
stop_active_vpn()
|
||||
|
||||
cmd = [
|
||||
"openvpn",
|
||||
"--config",
|
||||
str(config_path),
|
||||
"--daemon",
|
||||
"--writepid",
|
||||
str(PID_FILE),
|
||||
"--log",
|
||||
str(LOG_FILE),
|
||||
"--setenv",
|
||||
"VPN_NAME",
|
||||
vpn_name,
|
||||
]
|
||||
cmd.extend(extra_args)
|
||||
try:
|
||||
subprocess.run(cmd, check=True, cwd=str(config_path.parent))
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise VPNRuntimeError(f"OpenVPN failed to start for {vpn_name}: {exc}") from exc
|
||||
|
||||
deadline = time.time() + START_TIMEOUT
|
||||
while time.time() < deadline:
|
||||
pid = _read_pid()
|
||||
if pid and _is_pid_running(pid):
|
||||
_write_state(vpn_name)
|
||||
return
|
||||
time.sleep(0.5)
|
||||
|
||||
stop_active_vpn()
|
||||
raise VPNRuntimeError(f"Timed out waiting for OpenVPN to start for {vpn_name}")
|
||||
|
||||
|
||||
def list_available_vpns() -> list[str]:
|
||||
if not VPN_CONFIG_DIR.exists():
|
||||
return []
|
||||
names: list[str] = []
|
||||
for path in sorted(VPN_CONFIG_DIR.glob("*.conf")):
|
||||
names.append(path.stem)
|
||||
for path in sorted(VPN_CONFIG_DIR.glob("*.ovpn")):
|
||||
if path.stem not in names:
|
||||
names.append(path.stem)
|
||||
return names
|
||||
Reference in New Issue
Block a user