commit
This commit is contained in:
@@ -0,0 +1,194 @@
|
||||
"""Helpers for managing OpenVPN processes inside the container."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
VPN_CONFIG_DIR = Path(os.getenv("VPN_CONFIG_DIR", "/vpn/configs"))
|
||||
VPN_RUNTIME_DIR = Path(os.getenv("VPN_RUNTIME_DIR", "/vpn/runtime"))
|
||||
PID_FILE = VPN_RUNTIME_DIR / "openvpn.pid"
|
||||
STATE_FILE = VPN_RUNTIME_DIR / "active_vpn"
|
||||
LOG_FILE = VPN_RUNTIME_DIR / "openvpn.log"
|
||||
START_TIMEOUT = float(os.getenv("VPN_START_TIMEOUT", "15"))
|
||||
STOP_TIMEOUT = float(os.getenv("VPN_STOP_TIMEOUT", "10"))
|
||||
AUTH_DIRECTIVE = "auth-user-pass"
|
||||
|
||||
class VPNRuntimeError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_runtime_dir() -> None:
|
||||
VPN_RUNTIME_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _config_path(vpn_name: str) -> Path:
|
||||
VPN_CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
for suffix in (".conf", ".ovpn"):
|
||||
candidate = VPN_CONFIG_DIR / f"{vpn_name}{suffix}"
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
raise VPNRuntimeError(f"No config found for VPN '{vpn_name}' in {VPN_CONFIG_DIR}")
|
||||
|
||||
|
||||
def _prepare_auth_override(vpn_name: str, config_path: Path) -> list[str]:
|
||||
"""Return CLI args to supply a sanitized auth-user-pass file if needed."""
|
||||
try:
|
||||
lines = config_path.read_text().splitlines()
|
||||
except FileNotFoundError as exc:
|
||||
raise VPNRuntimeError(f"Config {config_path} missing: {exc}") from exc
|
||||
|
||||
auth_target: Path | None = None
|
||||
for raw in lines:
|
||||
line = raw.strip()
|
||||
if not line or line.startswith("#") or line.startswith(";"):
|
||||
continue
|
||||
if not line.lower().startswith(AUTH_DIRECTIVE):
|
||||
continue
|
||||
parts = line.split()
|
||||
if len(parts) < 2:
|
||||
continue
|
||||
auth_path = parts[1].strip('"')
|
||||
candidate = Path(auth_path)
|
||||
if not candidate.is_absolute():
|
||||
candidate = config_path.parent / candidate
|
||||
auth_target = candidate
|
||||
break
|
||||
|
||||
if not auth_target:
|
||||
return []
|
||||
if not auth_target.exists():
|
||||
raise VPNRuntimeError(
|
||||
f"Auth file {auth_target} referenced by {config_path} does not exist"
|
||||
)
|
||||
|
||||
dest = VPN_RUNTIME_DIR / f"{vpn_name}.auth"
|
||||
try:
|
||||
data = auth_target.read_bytes()
|
||||
except OSError as exc:
|
||||
raise VPNRuntimeError(f"Failed to read auth file {auth_target}: {exc}") from exc
|
||||
if not data.strip():
|
||||
raise VPNRuntimeError(f"Auth file {auth_target} is empty")
|
||||
|
||||
try:
|
||||
dest.write_bytes(data)
|
||||
dest.chmod(0o600)
|
||||
except OSError as exc:
|
||||
raise VPNRuntimeError(f"Failed to stage auth file at {dest}: {exc}") from exc
|
||||
|
||||
return ["--auth-user-pass", str(dest)]
|
||||
|
||||
|
||||
def _is_pid_running(pid: int) -> bool:
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def _read_pid() -> int | None:
|
||||
if PID_FILE.exists():
|
||||
try:
|
||||
return int(PID_FILE.read_text().strip())
|
||||
except ValueError:
|
||||
PID_FILE.unlink(missing_ok=True)
|
||||
return None
|
||||
|
||||
|
||||
def _write_state(vpn_name: str) -> None:
|
||||
STATE_FILE.write_text(vpn_name)
|
||||
|
||||
|
||||
def _clear_state() -> None:
|
||||
PID_FILE.unlink(missing_ok=True)
|
||||
STATE_FILE.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def get_active_vpn() -> str | None:
|
||||
pid = _read_pid()
|
||||
if not pid:
|
||||
_clear_state()
|
||||
return None
|
||||
if not _is_pid_running(pid):
|
||||
_clear_state()
|
||||
return None
|
||||
if STATE_FILE.exists():
|
||||
return STATE_FILE.read_text().strip() or None
|
||||
return None
|
||||
|
||||
|
||||
def stop_active_vpn() -> None:
|
||||
pid = _read_pid()
|
||||
if not pid:
|
||||
_clear_state()
|
||||
return
|
||||
if not _is_pid_running(pid):
|
||||
_clear_state()
|
||||
return
|
||||
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
deadline = time.time() + STOP_TIMEOUT
|
||||
while time.time() < deadline:
|
||||
if not _is_pid_running(pid):
|
||||
_clear_state()
|
||||
return
|
||||
time.sleep(0.5)
|
||||
|
||||
# escalate
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
_clear_state()
|
||||
|
||||
|
||||
def start_vpn(vpn_name: str) -> None:
|
||||
config_path = _config_path(vpn_name)
|
||||
_ensure_runtime_dir()
|
||||
extra_args = _prepare_auth_override(vpn_name, config_path)
|
||||
stop_active_vpn()
|
||||
|
||||
cmd = [
|
||||
"openvpn",
|
||||
"--config",
|
||||
str(config_path),
|
||||
"--daemon",
|
||||
"--writepid",
|
||||
str(PID_FILE),
|
||||
"--log",
|
||||
str(LOG_FILE),
|
||||
"--setenv",
|
||||
"VPN_NAME",
|
||||
vpn_name,
|
||||
]
|
||||
cmd.extend(extra_args)
|
||||
try:
|
||||
subprocess.run(cmd, check=True, cwd=str(config_path.parent))
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise VPNRuntimeError(f"OpenVPN failed to start for {vpn_name}: {exc}") from exc
|
||||
|
||||
deadline = time.time() + START_TIMEOUT
|
||||
while time.time() < deadline:
|
||||
pid = _read_pid()
|
||||
if pid and _is_pid_running(pid):
|
||||
_write_state(vpn_name)
|
||||
return
|
||||
time.sleep(0.5)
|
||||
|
||||
stop_active_vpn()
|
||||
raise VPNRuntimeError(f"Timed out waiting for OpenVPN to start for {vpn_name}")
|
||||
|
||||
|
||||
def list_available_vpns() -> list[str]:
|
||||
if not VPN_CONFIG_DIR.exists():
|
||||
return []
|
||||
names: list[str] = []
|
||||
for path in sorted(VPN_CONFIG_DIR.glob("*.conf")):
|
||||
names.append(path.stem)
|
||||
for path in sorted(VPN_CONFIG_DIR.glob("*.ovpn")):
|
||||
if path.stem not in names:
|
||||
names.append(path.stem)
|
||||
return names
|
||||
Reference in New Issue
Block a user